diff --git a/OWNERS b/OWNERS index 7b721dd643e3399d29ff649e2f76182de72421a4..2e949debf181a6e75fdb5b1e1e091ce7a39c7e69 100644 --- a/OWNERS +++ b/OWNERS @@ -13,35 +13,14 @@ approvers: - kun_8 - binghamhuang reviewers: -- leo920320 -- wo-wenjie -- ma-dongfang -- wuyulong11 -- alysongirl -- wangchao285 -- brightlyking -- chenhao_1209 -- feng123www -- zhang-mingyu-0813 -- snowflakephoenix -- Seanesmhxocism -- augboost -- fanxiaotong1995 -- sunboquan -- kun_8 -- Martin-M -- ly-qianxiao -- yang-minghai22 -- hu-xiao-bo - lv-kaimeng - litian_drinksnow -- blian -- cycoe -- machj -- zhengweifeng6 -- gong-siwei -- uniteone - binghamhuang -- wjchuee -- zhou-xianqi -- stby11 \ No newline at end of file +- wo-wenjie +- ly-qianxiao +- leo920320 +- sunboquan +- stby +- Seanesmhxocism +- TAJh +- czr9775 \ No newline at end of file diff --git a/README.md b/README.md index 014a4d59f07116c2519de1b2a69463484549642d..7c28869d371debd863409bd8dbe75baf586a5c7e 100644 --- a/README.md +++ b/README.md @@ -1,96 +1,84 @@ -# 变更通知 +# 🚨 重要通知 -原Ascend Training Tools工具更名为MindStudio Training Tools,MindStudio训练工具链。变更计划如下: +**1. Ascend Training Tools 更名为 MindStudio Training Tools (mstt)。** -1. 2024.06.25本代码仓名称变更为mstt。 -2. 2024.07.04 URL变更为[https://gitee.com/ascend/mstt](https://gitee.com/ascend/mstt),原始URL仍然可用,但建议使用新URL。 +**2. 本代码仓 URL 变更为 [https://gitee.com/ascend/mstt](https://gitee.com/ascend/mstt),原 URL 仍然可用(2024.07.04 )。** -# MindStudio Training Tools +**3. 不再维护:[api_accuracy_checker](./debug/accuracy_tools/api_accuracy_checker/) (2024.09.30下线)和[ ptdbg_ascend](./debug/accuracy_tools/ptdbg_ascend/) +(2024.09.30下线)** -MindStudio Training Tools,MindStudio训练工具链。针对训练&大模型场景,提供端到端命令行&可视化调试调优工具,帮助用户快速提高模型开发效率。 +**相关目录 mstt/debug/accuracy_tools/api_accuracy_checker 和 mstt/debug/accuracy_tools/ptdbg_ascend 将于 2024.09.30 删除。新版本的预检和 ptdbg 已经合到 mstt/debug/accuracy_tools/msprobe 目录下。** -## 模型训练迁移全流程 -![输入图片说明](debug/resources/model_training_migration_process.png) +--- -## 使用说明 +# 🧰 MindStudio Training Tools -### [分析迁移工具](https://gitee.com/ascend/mstt/wikis/工具介绍/分析迁移工具/分析迁移工具介绍) +![Build Status](https://img.shields.io/badge/build-passing-brightgreen) +![Commit Activity](https://img.shields.io/badge/commit%20activity-high-red) +![License: Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue) + +## [分析迁移工具](https://gitee.com/ascend/mstt/wikis/工具介绍/分析迁移工具/分析迁移工具介绍) 1. [脚本分析工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E5%88%86%E6%9E%90%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) - 脚本分析工具提供分析脚本,帮助用户在执行迁移操作前,分析基于GPU平台的PyTorch训练脚本中算子、三方库套件、亲和API分析以及动态shape的支持情况。 + 脚本分析工具可以帮助用户在执行迁移操作前,分析基于 GPU 平台的 PyTorch 训练脚本中算子、三方库套件、API 亲和性以及动态 shape 的支持情况。 2. [(推荐)自动迁移工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%87%AA%E5%8A%A8%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) - 自动迁移只需在训练脚本中导入库代码即可完成模型脚本迁移,使用方式较简单,且修改内容最少。 + 自动迁移工具只需在训练脚本中导入库代码即可完成模型脚本的迁移,使用方式简单,且修改内容少。 3. [脚本迁移工具](https://gitee.com/ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%84%9A%E6%9C%AC%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) - 脚本迁移工具提供后端命令行用于将GPU上训练的PyTorch脚本迁移至NPU上,得到新的训练脚本用于训练。 + 脚本迁移工具通过后端命令行,将 GPU 上训练的 PyTorch 脚本迁移至 NPU 上,得到新的训练脚本用于训练。 4. [训推一体权重转换工具](https://gitee.com/Ascend/mstt/wikis/%E5%B7%A5%E5%85%B7%E4%BB%8B%E7%BB%8D/%E5%88%86%E6%9E%90%E8%BF%81%E7%A7%BB%E5%B7%A5%E5%85%B7/%E8%AE%AD%E6%8E%A8%E4%B8%80%E4%BD%93%E6%9D%83%E9%87%8D%E8%BD%AC%E6%8D%A2%E5%B7%A5%E5%85%B7%E4%BD%BF%E7%94%A8%E6%8C%87%E5%AF%BC) - 训推一体权重转换工具,支持在GPU和NPU上训练好的模型转成加速推理支持的格式。 - -### [精度工具](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools) + 训推一体权重转换工具,支持在 GPU 和 NPU 上训练好的模型转成加速推理支持的格式。 -1. [api_accuracy_checker(Ascend模型精度预检工具)](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/api_accuracy_checker) +## [精度工具](./debug/accuracy_tools/) - 在昇腾NPU上扫描用户训练模型中所有API,进行API复现,给出精度情况的诊断和分析。 +[MindStudio Probe(msprobe,MindStudio 精度调试工具)](./debug/accuracy_tools/msprobe)。 -2. [ptdbg_ascend(PyTorch精度工具)](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/ptdbg_ascend) +## [性能工具](./profiler) - 进行PyTorch整网API粒度的数据dump、精度比对和溢出检测,从而定位PyTorch训练场景下的精度问题。 +1. [compare_tools(性能比对工具)](./profiler/compare_tools) -### [性能工具](https://gitee.com/ascend/mstt/tree/master/profiler) + 提供 NPU 与 GPU 性能拆解功能以及算子、通信、内存性能的比对功能。 -1. [compare_tools(性能比对工具)](https://gitee.com/ascend/mstt/tree/master/profiler/compare_tools) +2. [cluster_analyse(集群分析工具)](./profiler/cluster_analyse) - 提供NPU与GPU性能拆解功能以及算子、通信、内存性能的比对功能。 + 提供多机多卡的集群分析能力(基于通信域的通信分析和迭代耗时分析), 当前需要配合 MindStudio Insight 的集群分析功能使用。 -2. [cluster_analyse(集群分析工具)](https://gitee.com/ascend/mstt/tree/master/profiler/cluster_analyse) +3. [advisor](./profiler/advisor) - 提供多机多卡的集群分析能力(基于通信域的通信分析和迭代耗时分析), 当前需要配合Ascend Insight的集群分析功能使用。 + 将 Ascend PyTorch Profiler 或者 msprof 采集的 PyTorch 场景性能数据进行分析,并输出性能调优建议。 -3. [affinity_cpu_bind (亲和性cpu绑核工具) ](https://gitee.com/ascend/mstt/tree/master/profiler/affinity_cpu_bind) +4. [bind_core](./profiler/affinity_cpu_bind) - 提供亲和性CPU绑核能力,改善host_bound调度问题。 + 绑核脚本,支持非侵入修改工程代码,实现一键式绑核功能。 -### [Tensorboard](https://gitee.com/ascend/mstt/tree/master/plugins/tensorboard-plugins/tb_plugin) +## [Tensorboard](./plugins/tensorboard-plugins/tb_plugin) -Tensorboard支持NPU性能数据可视化插件PyTorch Profiler TensorBoard NPU Plugin。 +Tensorboard 支持 NPU 性能数据可视化插件 PyTorch Profiler TensorBoard NPU Plugin。 -支持将Ascend平台采集、解析的Pytorch Profiling数据可视化呈现,也兼容GPU数据采集、解析可视化。 +支持将 Ascend 平台采集、解析的 PyTorch Profiling 数据可视化呈现,也兼容 GPU 数据采集、解析可视化。 ## 分支维护策略 -MindStudio Training Tools工具版本分支的维护阶段如下: - -| **状态** | **时间** | **说明** | -| ------------------- | -------- | ------------------------------------------------ | -| 计划 | 1—3 个月 | 计划特性 | -| 开发 | 3个月 | 开发特性 | -| 维护 | 6—12个月 | 合入所有已解决的问题并发布版本 | -| 无维护 | 0—3 个月 | 合入所有已解决的问题,无专职维护人员,无版本发布 | -| 生命周期终止(EOL) | N/A | 分支不再接受任何修改 | - -## 现有分支的维护状态 - -MindStudio Training Tools分支版本号命名规则如下: - -mstt仓每年发布4个版本,每个版本都将对应一个分支;以v6.0为例,其将对应v6.0.RC1、v6.0.RC2、v6.0.RC3以及v6.0.0四个版本,在仓库中将存在与之对应的分支。 - -| **分支** | **状态** | **发布日期** | **后续状态** | **EOL日期** | -| ------------- | -------- | ------------ | ------------------------ | ----------- | -| **v6.0.0** | 维护 | 2023/12/12 | 预计2024/12/12起无维护 | | +1. MindStudio Training Tools 工具版本分支的维护阶段如下: -## 参与贡献 + | **状态** | **时间** | **说明** | + | ------------------- | -------- | ------------------------------------------------ | + | 计划 | 1—3 个月 | 计划特性 | + | 开发 | 3个月 | 开发特性 | + | 维护 | 6—12个月 | 合入所有已解决的问题并发布版本 | + | 无维护 | 0—3 个月 | 合入所有已解决的问题,无专职维护人员,无版本发布 | + | 生命周期终止(EOL) | N/A | 分支不再接受任何修改 | -1. Fork 本仓库 -2. 新建 xxx 分支 -3. 提交代码 -4. 新建 Pull Request +2. MindStudio Training Tools 分支版本号命名规则如下: -## 版本过渡提示 + mstt 仓每年发布 4 个版本,每个版本都将对应一个分支;以 v6.0 为例,其将对应 v6.0.RC1、v6.0.RC2、v6.0.RC3 以及 v6.0.0 四个版本,在仓库中将存在与之对应的分支。 -当前版本预检和ptdbg维护到2024/09/30,准备于2024/09/30下线,相关目录mstt/debug/accuracy_tools/api_accuracy_checker和mstt/debug/accuracy_tools/ptdbg_ascend将于2024/09/30删除。新版本的预检和ptdbg已经合到mstt/debug/accuracy_tools/atat目录下。 + | **分支** | **状态** | **发布日期** | **后续状态** | **EOL日期** | + | ------------- | -------- | ------------ | ------------------------ | ----------- | + | **v6.0.0** | 维护 | 2023.12.12 | 预计 2024.12.12 起无维护 | | diff --git a/debug/OWNERS b/debug/OWNERS index 09121722c9d7147133c6f111cd10b279979ebdb3..36e09821f37f6abed2a9da211ff7c1ef447218b9 100644 --- a/debug/OWNERS +++ b/debug/OWNERS @@ -5,7 +5,13 @@ approvers: - kun_8 - binghamhuang - brightlyking +- litian_drinksnow reviewers: - lv-kaimeng -- litian_drinksnow - binghamhuang +- xiangsen2 +- TAJh +- jiandaobao +- pengxiaopeng1 +- zhengxinqian +- louyujing diff --git a/debug/accuracy_tools/MANIFEST.in b/debug/accuracy_tools/MANIFEST.in index 7242c0c95627b56620b63c650dbbffbf8aaa2896..1af064685bcca89be479a465fd8b4f466047165f 100644 --- a/debug/accuracy_tools/MANIFEST.in +++ b/debug/accuracy_tools/MANIFEST.in @@ -1,2 +1,4 @@ -recursive-include atat/ * -recursive-exclude atat/test * \ No newline at end of file +include README.md +include LICENSE +recursive-include msprobe * +recursive-exclude msprobe/test * \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template new file mode 100644 index 0000000000000000000000000000000000000000..7630839aa937c6d0419629b5e93c34b51b71f295 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template @@ -0,0 +1,325 @@ +import json +import os +import math +from enum import Enum, auto +import torch +try: + import torch_npu +except ImportError: + pass + + +TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] +TORCH_BOOL_TYPE = ["torch.bool"] +TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int", + "torch.int64", "torch.long"] +TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float", + "torch.float64", "torch.double"] +TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"] +RAISE_PRECISION = {{ + "torch.float16": torch.float32, + "torch.half": torch.float32, + "torch.bfloat16": torch.float32, + "torch.float32": torch.float64, + "torch.float": torch.float64 +}} + + +class CompareStandard(Enum): + BINARY_EQUALITY_STANDARD = auto() + ABSOLUTE_THRESHOLD_STANDARD = auto() + ULP_ERROR_STANDARD = auto() + BENCHMARK_STANDARD = auto() + + +def get_device(): + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch_npu.npu.is_available(): + device = torch.device("npu") + else: + raise Exception("Error: This device is not NPU or GPU!") + return device + + +def generate_bool_tensor(low, high, shape): + low, high = int(low), int(high) + tensor = torch.randint(low, high + 1, shape) + bool_tensor = torch.gt(tensor, 0) + return bool_tensor + + +def generate_numerical_tensor(low, high, shape, data_dtype): + if data_dtype in TORCH_FLOAT_TYPE: + scale = high - low + rand01 = torch.rand(shape, dtype=eval(data_dtype)) + tensor = rand01 * scale + low + elif data_dtype in TORCH_INT_TYPE: + low, high = int(low), int(high) + tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype)) + else: + raise NotImplementedError(f"{{data_dtype}} is not supported!") + if torch.numel(tensor) == 0: + return tensor + tmp_tensor = tensor.reshape(-1) + tmp_tensor[0] = low + tmp_tensor[-1] = high + data = tmp_tensor.reshape(shape) + return data + + +def generate_random_tensor(info): + low, high = info.get('Min'), info.get('Max') + data_dtype = info.get('dtype') + shape = tuple(info.get('shape')) + if data_dtype == "torch.bool": + data = generate_bool_tensor(low, high, shape) + else: + data = generate_numerical_tensor(low, high, shape, data_dtype) + return data + + +def generate_real_tensor(data_path): + data_path = os.path.realpath(data_path) + data = torch.load(data_path) + return data + + +def generate_data(info): + data_type = info.get("type") + data_path = info.get("datapath") + if data_type in TENSOR_DATA_LIST: + if data_path: + data = generate_real_tensor(data_path) + else: + data = generate_random_tensor(info) + else: + data = info.get("value") + return data + + +def get_input(): +{args_element_assignment} + args_device = [{args_list_generator_device}] + args_bench = [{args_list_generator_bench}] +{kwargs_value_assignment} + kwargs_device = {{{kwargs_dict_generator_device}}} + kwargs_bench = {{{kwargs_dict_generator_bench}}} + return args_device, kwargs_device, args_bench, kwargs_bench + + +def exec_api_device(args, kwargs): + output_device = {api_type}.{api_name}(*args, **kwargs) + return output_device + + +def exec_api_bench(args, kwargs): + output_bench = {api_type}.{api_name}(*args, **kwargs) + return output_bench + + +def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol): + out_bench = out_bench.to(out_device.dtype) + min = torch.finfo(out_device.dtype).min + max = torch.finfo(out_device.dtype).max + bench_clip = torch.clamp(out_bench, min=min, max=max) + device_clip = torch.clamp(out_device, min=min, max=max) + clipped_abs_ae = torch.abs(device_clip - bench_clip) + clipped_re = clipped_abs_ae / abs_bench_with_eps + pass_mask = torch.less_equal(clipped_re, rtol) + both_nan_mask = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip)) + pass_mask = torch.logical_or(pass_mask, both_nan_mask) + not_pass_mask = torch.logical_not(pass_mask) + not_pass_mask = torch.logical_and(not_pass_mask, inf_nan_mask) + inf_nan_err_cnt = torch.sum(not_pass_mask) + return 0 if torch.sum(inf_nan_mask) == 0 else inf_nan_err_cnt / torch.sum(inf_nan_mask) + + +def compute_rmse(abs_err, normal_value_mask): + if torch.sum(normal_value_mask) == 0: + return 0 + else: + masked_ae = torch.where(normal_value_mask, abs_err, 0) + mse = torch.sum(torch.square(masked_ae)) / torch.sum(normal_value_mask) + rmse = torch.sqrt(mse) + return rmse + + +def compute_error_balance(out_device, out_bench): + larger_count = torch.sum(torch.greater(out_device - out_bench.to(out_device.dtype), 0)) + smaller_count = torch.sum(torch.less(out_device - out_bench.to(out_device.dtype), 0)) + total_count = torch.numel(out_bench) + error_balance = abs(larger_count - smaller_count) / total_count + return error_balance + + +def compare_tensor(out_device, out_bench, api_name): + if out_device.shape != out_bench.shape: + print("ERROR: shape of out_device and out_bench is not equal!") + return None + if torch.numel(out_bench) == 0: + print("Both out_device and out_bench have zero elements.") + return None + print(f"shape is {{out_bench.shape}}") + print(f"dtype of out_device is {{out_device.dtype}}") + print(f"dtype of out_bench is {{out_bench.dtype}}") + dtype_device = out_device.dtype + dtype_bench = out_bench.dtype + if str(dtype_device) in TORCH_FLOAT_TYPE and str(dtype_bench) in TORCH_FLOAT_TYPE \ + or str(dtype_device) in TORCH_INT_TYPE and str(dtype_bench) in TORCH_INT_TYPE \ + or str(dtype_device) in TORCH_BOOL_TYPE and str(dtype_bench) in TORCH_BOOL_TYPE: + out_device = out_device.to(torch.device("cpu")) + if str(dtype_device) in TORCH_BOOL_TYPE or str(dtype_device) in TORCH_INT_TYPE or compare_standard == CompareStandard.BINARY_EQUALITY_STANDARD: + print("compare standard: binary equality standard:") + error_number = torch.sum(out_device != out_bench).item() + error_rate = error_number / torch.numel(out_bench) + print(f"error rate is {{error_rate}}.") + else: + abs_err = torch.abs(out_device - out_bench) + abs_bench = torch.abs(out_bench) + if dtype_bench == torch.float32: + eps = 2 ** -23 + if dtype_bench == torch.float64: + eps = 2 ** -52 + abs_bench_with_eps = abs_bench + eps + rel_err = torch.abs(abs_err / abs_bench_with_eps) + device_finite_mask = torch.isfinite(out_device) + bench_finite_mask = torch.isfinite(out_bench.to(dtype_device)) + both_finite_mask = torch.logical_and(device_finite_mask, bench_finite_mask) + inf_nan_mask = torch.logical_not(both_finite_mask) + if compare_standard == CompareStandard.ABSOLUTE_THRESHOLD_STANDARD: + if dtype_device == torch.float16: + rtol, small_value, small_value_atol = 1.0e-3, 1.0e-3, 1.0e-5 + elif dtype_device == torch.bfloat16: + rtol, small_value, small_value_atol = 4.0e-3, 1.0e-3, 1.0e-5 + else: + rtol, small_value, small_value_atol = 1.0e-6, 1.0e-6, 1.0e-9 + small_value_mask = torch.less_equal(abs_bench, small_value) + small_value_mask = torch.logical_and(small_value_mask, both_finite_mask) + normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask)) + inf_nan_proportion = compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol) + rel_err_mask = torch.greater(rel_err, rtol) + rel_err_mask = torch.logical_and(rel_err_mask, normal_value_mask) + if torch.sum(normal_value_mask) == 0: + rel_err_proportion = 0 + else: + rel_err_proportion = torch.sum(rel_err_mask) / torch.sum(normal_value_mask) + abs_err_mask = torch.greater(abs_err, small_value_atol) + abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask) + if torch.sum(small_value_mask) == 0: + abs_err_proportion = 0 + else: + abs_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask) + print("compare standard: absolute threshold standard") + print(f"relative error ratio is {{rel_err_proportion}}") + print(f"absolute error ratio is {{abs_err_proportion}}") + elif compare_standard == CompareStandard.ULP_ERROR_STANDARD: + if dtype_device == torch.float16: + min_eb, exponent_num = -14, 10 + elif dtype_device == torch.bfloat16: + min_eb, exponent_num = -126, 7 + else: + min_eb, exponent_num = -126, 23 + eb = torch.where(abs_bench == 0, torch.zeros(out_bench.shape), torch.floor(torch.log2(abs_bench))) + eb = torch.maximum(eb, min_eb * torch.ones(out_bench.shape)) + if dtype_device == torch.float32: + ulp_err = (out_device.to(torch.float64) - out_bench).to(torch.float64) * torch.exp2(-eb + exponent_num).to(torch.float64) + else: + ulp_err = (out_device.to(torch.float32) - out_bench).to(torch.float32) * torch.exp2(-eb + exponent_num).to(torch.float32) + ulp_err = torch.abs(ulp_err) + max_ulp_err = torch.max(ulp_err) + mean_ulp_err = torch.mean(ulp_err) + if dtype_device == torch.float32: + ulp_err_proportion = torch.sum(ulp_err > 32) / torch.numel(out_bench) + else: + ulp_err_proportion = torch.sum(ulp_err > 1) / torch.numel(out_bench) + print("compare standard: ulp error standard") + print(f"maximum ulp error is {{max_ulp_err}}") + print(f"mean ulp error is {{mean_ulp_err}}") + print(f"ulp error proportion is {{ulp_err_proportion}}") + else: + if dtype_device == torch.float16: + small_value, small_value_atol = 1.0e-3, 1.0e-5 + elif dtype_device == torch.bfloat16: + small_value, small_value_atol = 1.0e-3, 1.0e-5 + else: + small_value, small_value_atol = 1.0e-6, 1.0e-9 + small_value_mask = torch.less_equal(abs_bench, small_value) + small_value_mask = torch.logical_and(small_value_mask, both_finite_mask) + normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask)) + abs_err_mask = torch.greater(abs_err, small_value_atol) + abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask) + if torch.sum(small_value_mask) == 0: + small_value_err_proportion = 0 + else: + small_value_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask) + rel_err = torch.where(normal_value_mask, rel_err, -1 * torch.ones(out_device.shape)) + if torch.max(rel_err) >= 0: + max_rel_err = torch.max(rel_err) + else: + max_rel_err = 0 + if torch.sum(normal_value_mask) == 0: + mean_rel_err = 0 + else: + mean_rel_err = torch.sum(torch.clamp(rel_err, min=0)) / torch.sum(normal_value_mask) + rmse = compute_rmse(abs_err, normal_value_mask) + error_balance = compute_error_balance(out_device, out_bench) + print("compare standard: benchmark standard") + print(f"small value error proportion is {{small_value_err_proportion}}") + print(f"maximum relative error is {{max_rel_err}}") + print(f"mean relative error is {{mean_rel_err}}") + print(f"root mean squared error is {{rmse}}") + print(f"error balance is {{error_balance}}") + else: + print(f"ERROR: out_device dtype is {{dtype_device}}, out_bench dtype is {{dtype_bench}}, not comparable.") + return None + + +def compare_element(out_device, out_bench, api_name): + if type(out_device) != type(out_bench): + print("ERROR: out_device and out_bench is not the same type!") + return None + if isinstance(out_bench, torch.Tensor): + print(f"data type: {{type(out_bench)}}") + compare_tensor(out_device, out_bench, api_name) + elif isinstance(out_bench, (bool, int, float, str)): + print(f"data type: {{type(out_bench)}}") + if out_device == out_bench: + print("PASS: out_device and out_bench equals.") + else: + print("ERROR: out_device and out_bench is not equal!") + else: + print(f"ERROR: comparison of type {{type(out_bench)}} is not supported.") + return None + + +def compare(out_device, out_bench, api_name): + print("Compare result:") + if type(out_device) != type(out_bench): + print("ERROR: out_device and out_bench is not the same type!") + print("Compare finished.") + return None + if isinstance(out_bench, (list, tuple)): + print(f"data type: {{type(out_bench)}}") + if len(out_device) != len(out_bench): + print("ERROR: len of out_device and out_bench is different!") + print("Compare finished.") + return None + for index, _ in enumerate(out_bench): + print(f"index {{index}}:") + compare_element(out_device[index], out_bench[index], api_name) + else: + compare_element(out_device, out_bench, api_name) + print("Compare finished.") + + +device = get_device() +api_name = "{api_name}" +compare_standard = {compare_standard} +torch.manual_seed({random_seed}) +for i in range({iter_times}): + print(f"iter: {{i}}:") + args_device, kwargs_device, args_bench, kwargs_bench = get_input() + output_device = exec_api_device(args_device, kwargs_device) + output_bench = exec_api_bench(args_bench, kwargs_bench) + compare(output_device, output_bench, api_name) diff --git a/debug/accuracy_tools/atat/mindspore/__init__.py b/debug/accuracy_tools/atat/mindspore/__init__.py deleted file mode 100644 index bb3f93567542e93ff913edf3daabcd3aedb91ee3..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/atat/mindspore/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from atat.mindspore.debugger.precision_debugger import PrecisionDebugger diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/compare/test_acc_compare.py b/debug/accuracy_tools/atat/test/pytorch_ut/compare/test_acc_compare.py deleted file mode 100644 index 5a82289a0003a332e41d9202b63e7bde4bc43c42..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/atat/test/pytorch_ut/compare/test_acc_compare.py +++ /dev/null @@ -1,17 +0,0 @@ -# coding=utf-8 -import unittest -from atat.pytorch.compare.acc_compare import rename_api - -class TestUtilsMethods(unittest.TestCase): - - def test_rename_api(self): - test_name_1 = "Distributed.broadcast.0.forward.input.0" - expect_name_1 = "Distributed.broadcast.input.0" - actual_name_1 = rename_api(test_name_1, "forward") - self.assertEqual(actual_name_1, expect_name_1) - - test_name_2 = "Torch.sum.0.backward.output.0" - expect_name_2 = "Torch.sum.output.0" - actual_name_2 = rename_api(test_name_2, "backward") - self.assertEqual(actual_name_2, expect_name_2) - \ No newline at end of file diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py deleted file mode 100644 index fa52fe0e1b05701cec9c8cdf41fe1586029c826e..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py +++ /dev/null @@ -1,38 +0,0 @@ -from unittest import TestCase -from unittest.mock import patch, mock_open - -from atat.core.common.const import Const -from atat.pytorch.pt_config import parse_json_config - - -class TestPtConfig(TestCase): - def test_parse_json_config(self): - mock_json_data = { - "task": "statistics", - "dump_path": "./dump/", - "rank": [], - "step": [], - "level": "L1", - "seed": 1234, - "statistics": { - "scope": [], - "list": [], - "data_mode": ["all"], - }, - "tensor": { - "file_format": "npy" - } - } - with patch("atat.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \ - patch("atat.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ - patch("atat.pytorch.pt_config.json.load", return_value=mock_json_data): - common_config, task_config = parse_json_config(None, None) - self.assertEqual(common_config.task, Const.STATISTICS) - self.assertEqual(task_config.data_mode, ["all"]) - - with patch("atat.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \ - patch("atat.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ - patch("atat.pytorch.pt_config.json.load", return_value=mock_json_data): - common_config, task_config = parse_json_config(None, Const.TENSOR) - self.assertEqual(common_config.task, Const.STATISTICS) - self.assertEqual(task_config.file_format, "npy") diff --git a/debug/accuracy_tools/graph_analyzer/README.md b/debug/accuracy_tools/graph_analyzer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cf0ad58c363d88e20787de5ea20d19df6da40074 --- /dev/null +++ b/debug/accuracy_tools/graph_analyzer/README.md @@ -0,0 +1,53 @@ +# Graph Analyzer + +#### 介绍 +图分析精度工具 + +#### 软件架构 +软件架构说明 + + +#### 安装教程 + +1. 下载源码 +``` +git clone https://gitee.com/ascend/mstt.git -b poc +``` +2. pip安装 +``` +cd debug/accuracy_tools/graph_analyzer +pip install . +``` + +#### 使用说明 +- ir图使用推荐: + - ir图推荐使用`anf_after_graph_build`图 + +#### 功能说明 +**`<>`表示必选参数,`[]`表示可选参数** +1. ir图结构分析 +使用方式: + +``` +graph_analyzer --ir [--output ] +``` +执行后,会自动分析ir文件,将ir文件分析后的结果输出到指定output目录下的struct.json,如果未指定output则默认为当前目录 + + +2. 数码关联 +数码关联是指数据和代码调用栈的关联,数据一般意义上指静态图`O0`,`O1`,`O2`下dump下来的数据 +目前支持: +- [x] 全量[tensor(npy)]数据格式的数码关联 +- [x] 统计值[statisitc]数据格式的数码关联 +- [x] 融合算子场景 +- [x] 支持超长算子名dump文件的自动解析 +- [x] 反向算子的正向绑定 + +使用方式: + +``` +graph_analyzer --ir --data [--output ] +``` + +- 如果是全量模式,则会把数据文件路径和代码调用栈的关联关系存到output路径下的mapping.csv中 +- 如果是统计值模式,则会把统计值csv中每个条目加上该条目对应的代码栈 diff --git a/debug/accuracy_tools/atat/__init__.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/__init__.py rename to debug/accuracy_tools/graph_analyzer/graph_analyzer/__init__.py diff --git a/debug/accuracy_tools/graph_analyzer/graph_analyzer/bind.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/bind.py new file mode 100644 index 0000000000000000000000000000000000000000..5958b618fa1c4eb46ce9d1785dc35794ebaeeb9b --- /dev/null +++ b/debug/accuracy_tools/graph_analyzer/graph_analyzer/bind.py @@ -0,0 +1,159 @@ +import os +import logging +import glob +from typing import Dict, List +from pathlib import Path +import pandas as pd +from graph_analyzer.graph import GraphNode + + +# 定义Trie节点 +class TrieNode: + def __init__(self): + self.children = {} + self.is_end_of_key = False + self.value = None + +# 定义Trie树 +class Trie: + def __init__(self): + self.root = TrieNode() + + # 向Trie中插入一个键 + def insert(self, key, value): + node = self.root + for char in key: + if char not in node.children: + node.children[char] = TrieNode() + node = node.children[char] + # 标记结束位置 + node.is_end_of_key = True + node.value = value + + # 在name字符串中查找所有匹配的键 + def search_in_string(self, string): + matched_values = [] + for i in range(len(string)): + node = self.root + j = i + # 从字符串的每个字符开始,逐字符查找匹配 + while j < len(string) and string[j] in node.children: + node = node.children[string[j]] + if node.is_end_of_key: + matched_values.append(node.value) + j += 1 + return matched_values + +# 定义匹配函数 +def match_codes(trie, name): + matched_nodes = trie.search_in_string(name) + matched_codes = ['\n'.join(ii.code_info) for ii in matched_nodes] + return '\n'.join(matched_codes) + + +def match_names(trie, name): + matched_nodes = trie.search_in_string(name) + matched_names = [ii.scope for ii in matched_nodes] + return '\n'.join(matched_names) + + +def complex_map(df, match_dict): +# 构建Trie树并插入所有键 + trie = Trie() + for key, value in match_dict.items(): + trie.insert(key, value) + + df['Code Stack'] = df['Op Name'].apply(lambda name: match_codes(trie, name)) + df['Scope Name'] = df['Op Name'].apply(lambda name: match_names(trie, name)) + return df + + +def find_npy_files(npy_path): + npy_files = [] + # 检查当前路径是否是一个以 .npy 结尾的文件 + if npy_path.endswith('npy') and os.path.isfile(npy_path): + npy_files.append(Path(npy_path).resolve()) + return npy_files + + npy_files = list(Path(npy_path).rglob('*.npy')) + return npy_files + + +def write_to_csv(param: Dict, output_dir: str, append: bool): + # 打开CSV文件以写入模式 + os.makedirs(output_dir, exist_ok=True) + file_name = os.path.join(output_dir, "code.csv") + data = [(name, res1, res2) for name, (res1, res2) in param.items()] + df = pd.DataFrame(data, columns=['File Path', 'Code Stacks', 'Scope Name']) + # 如果 append 为 True 并且文件已经存在,追加写入 + if append and os.path.exists(file_name): + # 清洗数据,筛选掉空字符串 + df = df[(df['Code Stacks'] != '') | (df['Scope Name'] != '')] + df.to_csv(file_name, mode='a', header=False, index=False) + # 否则,覆盖写入或者文件不存在时正常写入 + else: + df.to_csv(file_name, mode='w', header=True, index=False) + + +def find_statistic_files(directory): + if not os.path.isdir(directory): + return [] + pattern = os.path.join(directory, '**', "statistic.csv") + statistic_files = list(glob.glob(pattern)) + return statistic_files + + + +def bind_for_statistic(statistic_files: List[str], match_dict: Dict): + for statistic_file in statistic_files: + df = pd.read_csv(statistic_file) + df = complex_map(df, match_dict) + logging.info("Processing %s completed, code stack saved in %s", statistic_file, statistic_file) + df.to_csv(statistic_file, index=False) + + +def bind_code_info_for_data(input_dir: str, nodes: Dict[str, GraphNode]) -> Dict[str, str]: + # 待重构后优化性能 + match_dict = {} + for node in nodes.values(): + # 屏蔽子图节点 + if node.is_subgraph: + continue + # 获取规范化后的scope name + scope_name = node.scope.replace("/", "_") + match_dict[scope_name] = node + npy_files = find_npy_files(input_dir) + + bind_result = {} + if not npy_files: + statistic_files = find_statistic_files(input_dir) + if statistic_files: + bind_for_statistic(statistic_files, match_dict) + return bind_result + + for npy_file in npy_files: + directory, file_name = os.path.split(npy_file) # 拆分路径 + name_without_ext = os.path.splitext(file_name)[0] # 提取文件名(去掉扩展名) + if '.' not in name_without_ext: + # 3. 读取find.csv文件 + csv_file_path = os.path.join(directory, 'mapping.csv') + df = pd.read_csv(csv_file_path, header=None) + + # 4. 查找是否有与xxx.npy匹配的条目 + matching_row = df[df[0] == file_name] # 假设A列存储文件名 + if not matching_row.empty: + corresponding_name = matching_row[1].values[0] + logging.info("The corresponding name in column B is: %s", corresponding_name) + else: + corresponding_name = None + logging.info("No entry found for %s in find.csv.", file_name) + name_without_ext = os.path.splitext(corresponding_name)[0] + npy_path = os.path.realpath(npy_file) + node_scope = name_without_ext.split(".")[1] + trie = Trie() + for key, value in match_dict.items(): + trie.insert(key, value) + bind_code = match_codes(trie, node_scope) + bind_name = match_names(trie, node_scope) + bind_result[npy_path] = (bind_code, bind_name) + return bind_result diff --git a/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..564f56c861f847de5f82f33d2cc91295a0871862 --- /dev/null +++ b/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph.py @@ -0,0 +1,117 @@ +from typing import List, Dict, Union +from collections import defaultdict, deque + +class GraphNode: + def __init__(self, name: str, pos: int = -1, unique_name: str = "", operator_name: str = "", return_variable: str = "", return_value: str = "", + var_inputs: List[str] = None, has_constant_input: bool = False, unique_id: str="", scope: str = "", code_info: List[str] = None, + is_subgraph: bool = False, attrs: Union[Dict[str, str], List[str]] = None): + self.name = name + self.unique_name = unique_name + self.pos = pos + self.operator_name = operator_name + self.return_variable = return_variable + self.return_value = return_value + self.var_inputs = var_inputs if var_inputs else [] + self.has_constant_input = has_constant_input + self.unique_id = unique_id + self.scope = scope + self.code_info = code_info if code_info else [] + self.attrs = attrs if attrs else ({} if not is_subgraph else []) + self.nodes = {} # Internal nodes if this is a subgraph + self.predecessors = [] # Predecessor nodes + self.successors = [] # Successor nodes + self.is_subgraph = is_subgraph + + def trace_back_ancestors(self, ancestors: List[str], visited: Dict[str, bool], parser) -> None: + if visited[self.unique_name]: + return + visited[self.unique_name] = True + ancestors.append(self.unique_name) + for predecessor in self.predecessors: + predecessor.trace_back_ancestors(ancestors, visited, parser) + + +class Graph: + def __init__(self, nodes): + self.nodes = set(nodes.values()) + + def topological_sort(self): + # 创建邻接表和入度表 + nodes = self.nodes + in_degree = {node: len(node.predecessors) for node in nodes} + + # 初始化队列,将所有入度为 0 的节点加入队列 + queue = deque([node for node in nodes if in_degree[node] == 0]) + topo_order = [] + + # Kahn算法的拓扑排序 + while queue: + node = queue.popleft() + topo_order.append(node) + + for successor in node.successors: + in_degree[successor] -= 1 + if in_degree[successor] == 0: + queue.append(successor) + + return topo_order + + def find_independent_nodes(self, subset_nodes): + # 获取整个图的拓扑排序 + + topo_order = self.topological_sort() + + # 将子集节点记录为集合,方便查找 + subset_set = set(subset_nodes) + + # 追踪哪些子集节点有被访问过 + visited = set() + + # 筛选出不被其他子集节点依赖的节点 + independent_nodes = [] + + # 按照拓扑排序遍历 + for node in topo_order: + if node in subset_set: + # 如果该节点在子集中,检查它是否已经被访问 + if node not in visited: + independent_nodes.append(node) + # 将该节点指向的所有邻居标记为访问过(被依赖过) + for successor in node.successors: + if successor in subset_set: + visited.add(successor) + return independent_nodes + +def find_boundary_nodes(nodes, domain_level): + domain_structure = defaultdict(lambda: {'boundary': {'upper': set(), 'lower': set()}, 'nodes': set()}) + + for node in nodes: + if node.scope.startswith("Gradient"): + continue + node_new_scope = node.scope.split('/') + if domain_level <= len(node_new_scope) - 1: # 确保不使用最后一级 + current_domain = '/'.join(node_new_scope[:domain_level]) + domain_structure[current_domain]['nodes'].add(node) + + for domain, data in domain_structure.items(): + # 遍历域内的节点,寻找上边界和下边界 + for node in data['nodes']: + if not node.operator_name.startswith("Prim"): + continue + node_scope = node.scope.split('/') + for succ in node.successors: + succ_scope = succ.scope.split('/') + if succ.scope.startswith("Gradient") or len(succ_scope) == 2: + continue + if (succ.operator_name != "Param" and succ.operator_name != "Constant") and node_scope[:domain_level] != succ_scope[:domain_level]: + data['boundary']['lower'].add(node.name) + for pred in node.predecessors: + pred_scope = pred.scope.split('/') + if (pred.operator_name != "Param" and pred.operator_name != "Constant") and node_scope[:domain_level] != pred_scope[:domain_level]: + data['boundary']['upper'].add(node.name) + + # 递归处理子域 + sub_nodes = [node for node in data['nodes'] if len(node.scope) > domain_level] + if sub_nodes: + domain_structure[domain].update(find_boundary_nodes(sub_nodes, domain_level + 1)) + return domain_structure diff --git a/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph_parser.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..376473ca0ea444df978bfa81c5a96ccb7c3ff55a --- /dev/null +++ b/debug/accuracy_tools/graph_analyzer/graph_analyzer/graph_parser.py @@ -0,0 +1,214 @@ +import re +import logging +from typing import Tuple, List, Dict +from graph_analyzer.graph import GraphNode + + +class Parser: + def __init__(self): + self.nodes = {} + self.local_dict = {} + self.number_dict = {} + + @staticmethod + def parse_subgraph_attributes(text: str, subgraph_node: GraphNode, start_pos: int, end_pos: int) -> None: + subgraph_attr_pattern = re.compile(r'subgraph attr:\s*(.*)', re.DOTALL) + match = subgraph_attr_pattern.search(text, start_pos, end_pos) + if match: + attrs = match.group(1).strip().split('\n') + if isinstance(subgraph_node.attrs, list): + subgraph_node.attrs.extend(attrs) + + @staticmethod + def parse_graph_attributes(text: str, graph_node: GraphNode) -> None: + attr_pattern = re.compile(r'# Attrs:\s*(.*)', re.DOTALL) + match = attr_pattern.search(text, graph_node.pos) + if match: + attrs = match.group(1).strip().split('\n') + for attr in attrs: + if not attr: # if end line + break + key, value = attr.split(':') + if isinstance(graph_node.attrs, dict): + graph_node.attrs[key.strip()] = value.strip() + + @staticmethod + def parse_code_info(text: str, start_pos: int, end_pos: int) -> List[str]: + code_info = [] + code_info_pattern = re.compile(r'# .*', re.MULTILINE) + final_pos = end_pos if end_pos else len(text) - 1 + lines = text[start_pos + 1:final_pos].split('\n') + for line in lines: + match = code_info_pattern.search(line) + if not match: + break + code_info.append(match.group(0).strip('# ').strip('/')) + return code_info + + @staticmethod + def extract_bracket_content(text: str, start_pos: int) -> Tuple[str, int]: + stack = [] + content = [] + for i in range(start_pos, len(text)): + char = text[i] + if char == '(': + stack.append('(') + elif char == ')': + stack.pop() + if not stack: + content.append(char) + return ''.join(content), i + content.append(char) + raise ValueError("Mismatched parentheses") + + # check ok + @staticmethod + def find_matching_brace(text: str, start_pos: int) -> int: + stack = [] + for i in range(start_pos, len(text)): + if text[i] == '{': + stack.append('{') + elif text[i] == '}': + stack.pop() + if not stack: + return i + raise ValueError("Matching closing brace not found") + + # check ok + @staticmethod + def extract_constants(inputs_str: str) -> List[str]: + constant_pattern = re.compile(r'\b(\w+\(.*?\))') + constants = constant_pattern.findall(inputs_str) + return constants + + def parse_func_graph(self, text: str) -> None: + func_graph_pattern = re.compile(r'# IR entry: @(\S+)') + matches = func_graph_pattern.finditer(text) + for match in matches: + func_name = match.group(1) + func_graph_info = GraphNode(name=func_name, pos=match.start(), is_subgraph=False) + self.nodes[func_name] = func_graph_info + + def parse_nodes(self, text: str, subgraph_info: GraphNode) -> None: + node_pattern = re.compile(r'(%\d+)\((\S+)\)\s*=\s*(\S+)\(') + matches = list(node_pattern.finditer(text)) + for i, match in enumerate(matches): + series_number = match.group(1) + variable_name = match.group(2) + operator_name = match.group(3) + unique_name = "&".join([series_number, variable_name]) + self.local_dict[series_number] = unique_name + + args_str, end_pos = self.__class__.extract_bracket_content(text, match.end() - 1) + inputs = re.findall(r'%\w+', args_str) + subgraph_inputs = re.findall(r'@\w+', args_str) + inputs += subgraph_inputs + + constants = self.__class__.extract_constants(args_str) + + scope_pattern = re.compile(r'# .*scope.*:\s*\((.*?)\)', re.IGNORECASE | re.MULTILINE) + # [^:]scope[^:]:\s*\((.*?)\) + scope_match = scope_pattern.search(text, end_pos) + scope = scope_match.group(1) if scope_match else "" + + id_pattern = re.compile(r'.*cnode_primal_attrs:\s*\{.*\b(?:forward_unique_id|unique_id):\s*\"(\d+)\".*', re.IGNORECASE) + unique_id_match = id_pattern.search(text, end_pos, scope_match.start()) + unique_id = unique_id_match.group(1) if unique_id_match else None + + if scope: + next_match = matches[i + 1].start() - 1 if i < len(matches) - 1 else None + code_info = self.__class__.parse_code_info(text, scope_match.end(), next_match) + else: + code_info = None + + node_info = GraphNode(name=variable_name, unique_name=unique_name, operator_name=operator_name, var_inputs=inputs + constants, unique_id=unique_id, scope=scope, code_info=code_info) + + if unique_id and scope and not scope.startswith("Gradients"): + self.number_dict[unique_id] = node_info + + if subgraph_info: + subgraph_info.nodes[variable_name] = node_info # 这里不用unique_name会有事吗 + + if not self.nodes.get(unique_name, None): + self.nodes[unique_name] = node_info + else: + pass + + for const in constants: + if const not in self.nodes: + const_node = GraphNode(name=const, operator_name="Constant", var_inputs=[], has_constant_input=True) + if not self.nodes.get(const_node, None): + self.nodes[const] = const_node + if subgraph_info: + subgraph_info.nodes[const] = const_node + self.local_dict[const] = const + + for input_var in node_info.var_inputs: + if input_var in self.local_dict or input_var in self.nodes: + input_name = self.local_dict.get(input_var, input_var) # 没有就用原来名字 + input_node = self.nodes.get(input_name, None) + if input_node: + node_info.predecessors.append(input_node) + input_node.successors.append(node_info) + else: + param_node = GraphNode(name=input_var, operator_name="Param", var_inputs=[], has_constant_input=False) + if not self.nodes.get(input_var, None): + self.nodes[input_var] = param_node + node_info.predecessors.append(param_node) + param_node.successors.append(node_info) + + # check ok + def extract_callees(self, text: str) -> None: + for node_info in self.nodes.values(): + func_start_pos = node_info.pos + func_end_pos = text.find('}', func_start_pos) + func_text = text[func_start_pos:func_end_pos] + callee_pattern = re.compile(r'Partial\(@(\S+)\(') + callee_matches = callee_pattern.finditer(func_text) + for callee_match in callee_matches: + callee_name = callee_match.group(1) + if callee_name not in node_info.var_inputs: + node_info.var_inputs.append(callee_name) + + # check ok + def parse_subgraphs(self, text: str) -> None: + subgraph_pattern = re.compile(r'subgraph\s+@(\S+)(\([^\)]*\))?\s+.*\{') + matches = list(subgraph_pattern.finditer(text)) + end_pos = 0 + for match in matches: + last_pos = end_pos + 2 + subgraph_name = match.group(1).split('(')[0] + start_pos = match.start() + end_pos = self.__class__.find_matching_brace(text, start_pos) + subgraph_text = text[start_pos:end_pos + 1] + attr_text = text[last_pos:start_pos] + subgraph_info = GraphNode(name=subgraph_name, pos=start_pos, is_subgraph=True) + self.nodes[subgraph_name] = subgraph_info + self.__class__.parse_subgraph_attributes(text, subgraph_info, last_pos, start_pos) + self.parse_nodes(subgraph_text, subgraph_info) + subgraph_info.end = end_pos + logging.info('Parsed subgraph: %s', subgraph_name) + + # check ok + def count_nodes(self) -> Tuple[int, int]: + total_nodes = len(self.nodes) + total_cnodes = sum(1 for node in self.nodes.values() if node.name.startswith('CNode')) + return total_nodes, total_cnodes + + # check ok + def create_backward_map(self): + for node in self.nodes.values(): + if node.scope and node.scope.startswith("Gradients"): + related_forward_node = self.number_dict.get(node.unique_id, None) + if related_forward_node: + node.code_info = related_forward_node.code_info + + def parse(self, text: str) -> None: + self.parse_func_graph(text) + self.parse_subgraphs(text) + self.parse_nodes(text, None) + self.extract_callees(text) + self.create_backward_map() + + def get_nodes(self) -> Dict[str, GraphNode]: + return self.nodes diff --git a/debug/accuracy_tools/graph_analyzer/graph_analyzer/main.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/main.py new file mode 100644 index 0000000000000000000000000000000000000000..18df68f01670018fb989e3b35eaaa5658a14461a --- /dev/null +++ b/debug/accuracy_tools/graph_analyzer/graph_analyzer/main.py @@ -0,0 +1,23 @@ +import os +import re +import sys +import argparse +from typing import List +from graph_analyzer.processor import process + + + +def main(): + parser = argparse.ArgumentParser(description="IR Parser") + parser.add_argument('--ir', type=str, required=True, help="Path to the graph file") + parser.add_argument('--data', type=str, required=False, default=None, help="Path to data dir") + parser.add_argument('--node-list', type=List, required=False, default=None, help="Error node list") + parser.add_argument('--output', type=str, required=False, default="./", help="Path to output dir") + parser.add_argument('--append', action='store_true', help="Whether to append to the CSV file if it exists") + args = parser.parse_args() + + process(args) + + +if __name__ == "__main__": + main() diff --git a/debug/accuracy_tools/graph_analyzer/graph_analyzer/processor.py b/debug/accuracy_tools/graph_analyzer/graph_analyzer/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..370008da591df2566cfc8d37b8bda9dfe5b57bff --- /dev/null +++ b/debug/accuracy_tools/graph_analyzer/graph_analyzer/processor.py @@ -0,0 +1,45 @@ +import os +import stat +import json +from graph_analyzer.graph import Graph, find_boundary_nodes +from graph_analyzer.graph_parser import Parser +from graph_analyzer.bind import bind_code_info_for_data, write_to_csv + +def serialize_domain_structure(domain_structure): + serialized_structure = {} + for domain, data in domain_structure.items(): + serialized_structure[domain] = { + 'boundary': {'upper': list(data['boundary']['upper']), 'lower': list(data['boundary']['lower'])}, + 'nodes': [node.name for node in data['nodes']] + } + # 递归处理子域,避免解析 boundary 和 nodes 部分 + for key in data: + if key not in ['boundary', 'nodes', 'upper', 'lower']: + serialized_structure[domain][key] = serialize_domain_structure({key: data[key]}) + return serialized_structure + +def process(args): + ir_file_path = args.ir + with open(ir_file_path, 'r') as f: + input_text = f.read() + + parser = Parser() + parser.parse(input_text) + + nodes = parser.get_nodes() + graph = Graph(nodes) + + if args.data: + bind_result = bind_code_info_for_data(args.data, nodes) + if bind_result: + # 将 append 参数传递给 write_to_csv + write_to_csv(bind_result, args.output, args.append) + + domain_structure = find_boundary_nodes(nodes.values(), 1) + output_structure = serialize_domain_structure(domain_structure) + output_file = os.path.join(args.output, "struct.json") + + # 使用 os.open() 指定文件权限 + fd = os.open(output_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) # 0o600 表示读写 + with open(fd, "w") as f: + json.dump(output_structure, f, indent=4) diff --git a/debug/accuracy_tools/graph_analyzer/setup.py b/debug/accuracy_tools/graph_analyzer/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4fbb52fb0ef997eb0339fb49c6393fd33f9019 --- /dev/null +++ b/debug/accuracy_tools/graph_analyzer/setup.py @@ -0,0 +1,26 @@ +import os +from setuptools import setup, find_packages + +setup( + name='graph_analyzer', + version='1.0.0', + packages=find_packages(), + install_requires=[ + ], + entry_points={ + 'console_scripts': [ + 'graph_analyzer=graph_analyzer.main:main', # Allows running `graph_analyzer` + ], + }, + author='TAJh', + description='Graph Analyzer used for graph analysis and dump data analysis', + long_description=open('README.md', encoding='utf-8').read() if os.path.exists('README.md') else '', + long_description_content_type='text/markdown', + url='https://gitee.com/tajh/graph_analyzer.git', # Replace with the correct URL + classifiers=[ + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: MIT License', + 'Operating System :: OS Independent', + ], + python_requires='>=3.6', +) diff --git a/debug/accuracy_tools/kj600/README.md b/debug/accuracy_tools/monitor/README.md similarity index 54% rename from debug/accuracy_tools/kj600/README.md rename to debug/accuracy_tools/monitor/README.md index bd97acf6dc9b35b21d23d58ff1e6cb56a82e92fb..d9a1171f4ea9237931f704e5b4cf5b6a68642dbb 100644 --- a/debug/accuracy_tools/kj600/README.md +++ b/debug/accuracy_tools/monitor/README.md @@ -1,4 +1,4 @@ -# TensorProbe (codename:kj600) 模型训练状态监控工具 +# Monitor 模型训练状态监控工具 ## 简介 @@ -10,7 +10,7 @@ | 依赖软件 | |-------------| -| torch | +| torch>=2.0 | | torch_npu | | torchvision | | tensorboard | @@ -18,25 +18,126 @@ | sqlalchemy | | pymysql | -### 2. 安装 kj600 +### 2. 安装 monitor -方式一:从 git 直接安装 +方式一:下载源码安装 ``` -pip install git+https://gitee.com/xiangsen2/kj600.git +git clone -b poc https://gitee.com/ascend/mstt.git +cd mstt/debug/accuracy_tools/monitor +pip install . ``` -方式二:下载源码安装 +## 快速上手 +### 梯度监控 +模型训练状态的异常通常会反映在loss和梯度上,通过对模型各个模块梯度的监控,可以帮助快速定位异常的第一现场。 + +1. 输出目录 +监控结果写入tensorboard的event文件/csv中,设置输出路径(默认为`monitor_output`,通过环境变量配置) +```bash +export MONITOR_OUTPUT_DIR=/xxx/output_dir ``` -git clone https://gitee.com/xiangsen2/kj600.git -cd kj600 -pip install . + +2. 在训练脚本中使能工具(Megatron-LM) + +``` +from monitor.module_hook import TrainerMon +hooker = TrainerMon("./monitor_config.json", process_group=None, params_have_main_grad=True) + +model, optimizer, opt_param_scheduler = setup_model_and_optimizer( + model_provider, model_type) +# 模型、优化器初始化后使能工具 + +hooker.monitor_gnorm_with_ad( + model, grad_acc_steps=args.global_batch_size//args.data_parallel_size//args.micro_batch_size, optimizer=optimizer, dp_group=mpu.get_data_parallel_group(), tp_group=mpu.get_tensor_model_parallel_group()) + + +# 可以在任意位置获取当前的梯度统计量, 不同调用位置不能保证reduce已完成 +reduced, unreduced = hooker.generate_wgrad_metrics() +``` + + +| 字段名字 | 是否必选 | 解释 | +| ------------------------------------------------------------ | -------- | -------- | +|"grad_acc_steps"| 必选 |梯度累积的步数,当micro step=grad acc steps时,会触发反向hook获取模型梯度| +|"optimizer"| 可选 |各种并行域reduce后的梯度在opt.step前获取,数据写入在step后进行。默认patch pytorch的优化器,传入其他优化器(如MegatronOptimizer)可以调整工具行为,如clip_grad发生在megatron的优化器中,pytorch的优化器之前。| +|"dp_group"| 可选 |训练过程中的dp_group。dp域通信后,group内所有rank的梯度相同,落盘数据冗余。提供dp_group后,工具仅保留每个dp_group的第一个rank的梯度| +|"tp_group"| 可选 |训练过程中的tp_group。tp域通信后,group内部分参数所有rank的梯度相同,落盘数据冗余。提供tp_group后,工具仅保留每个tp_group中冗余参数在第一个rank的梯度。当前适配Megatron core_v0.6.0, 通过权重属性`tensor_model_parallel`判断是否冗余| + +3. 在json文件中配置工具 +``` +{ + "targets": { + "module": {}, + "module.module.language_model.encoder.layers.0": {"input_grad":"tuple[1]:0", "output_grad":"tuple[2]:0"} + }, + "print_struct": false, # 若不了解模型结构,可以打开print_struct打印模型结构 + "module_ranks": [0,1,2,3], # 需要监控的rank + "wg_distribution": true, + "format": "csv", # 如果不需要落盘文件,设置为 "api" + "ops": ["norm", "min", "max", "mean"], + "eps": 1e-8, + "ndigits: 6 +} +``` + +4. 结果验证 +训练日志中通常会打屏一个训练步的grad norm。提供了脚本校验落盘数据和打屏信息的一致性。 +```bash +python monitor/unittest/test_monitor.py -m monitor_output/Aug13_02-27-5 -l logs/train_gpt3_TP2_PP1_CP1_monitor.log -d 2 -t 2 +``` +`-m`指定落盘csv的路径前缀。`-l`指定训练日志。脚本通过关键词`grad norm: `匹配训练日志中的grad norm,根据实际情况修改。从落盘数据计算的grad norm和日志中的grad norm相对偏差超过1%,会有警告。`-d`、`--dp_size`声明data parallel size,`-t`、`--tp_size`声明tensor paralllel size。 +示例输出: +```txt +rank 2 is duplicated in dp group +rank 3 is duplicated in dp group +grad norm in consiste between training log and reduced gradients monitored +grad mean is in consisten between unreduced grad and reduced grad monitord. +``` +需要提供并行相关参数,具体参见: +```bash +python monitor/unittest/test_monitor.py -h +``` +### 梯度异常时序判断 +0. 训练前配置相关参数 +工具支持自动判断训练过程中的梯度异常,需要在配置文件中设置alert相关字段。`AnomalyTurbulence`会将当前数值与历史均值比较,如果相对偏差超过阈值,会在打屏信息中提示用户。如果打开`dump`选项,则会将异常梯度相关信息落盘,用于后续时序判断。 +```json + "alert": { + "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}], + "dump": true + }, +``` +1. 实例化工具时传入流水线并行group +```python +hooker = TrainerMon("./monitor_config.json", process_group=mpu.get_pipeline_model_parallel_group(), params_have_main_grad=True) +``` +照常开始训练 + +2. 进入工具路径启动异常分析脚本: +```shell +cd monitor/ +python3 anomaly_analyse.py -d $MONITOR_OUTPUT_DIR/anomaly_detected +``` +支持以下参数配置 +| 字段名字 | 解释 | 是否必选释 | +| ------ | -------- | -------- | +|-d 或 --data_path| 指定梯度异常落盘文件夹,梯度监控功能输出,一般为$MONITOR_OUTPUT_DIR/anomaly_detected。|是 | +|-o 或 --out_path| 排序后的异常落盘文件地址,默认在--data_path路径下落盘一个anomaly_analyse.json文件| 否 | +|-k 或 --topk| 指定保留前topk个异常,默认为8| 否 | +|-s 或 --step_list| 指定分析的step范围,默认为[]| 否 | + +## 已知问题 +- Megatron中使用流水线并行时,完成当前stage的计算并将output传递到下一个stage后,会调用`deallocate_output_tensor`释放output。当工具使能后,部分功能会给一些module注错反向hook,hook功能可能为output创建一个view副本,导致output内存无法释放。如果工具使能后出现如下报错,则需要跳过deallocate的步骤。在较新的megatron代码中,可以在`megatron/training/arguments.py`中将`kw_args['deallocate_pipeline_outputs']`设为False,或在`megatron/core/pipeline_parallel/schedules.py`中跳过`deallocate_output_tensor`的调用 +```bash +File "~/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 117, in deallocate_output_tensor + assert out._base is None, "counter-productive to free a view of another tensor." +AssertionError: counter-productive to free a view of another tensor. ``` -# 快速上手 +## 详细配置 - 下面以Ascend/ModelLink训练框架为例,给出kj600工具的使用方法。 + 下面以Ascend/ModelLink训练框架为例,给出monitor工具的使用方法。 1. 在ModelLink的根目录,创建json配置文件,如llama2_config.json,内容如下: @@ -54,8 +155,10 @@ pip install . "cc_distribution": {"enable":true, "cc_codeline":[]}, "alert": { "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}], - "inform": {"recipient": "database", "connection_str": "mysql+pymysql://username:password@host:port/database"} + "inform": {"recipient": "database", "connection_str": "mysql+pymysql://username:password@host:port/database"}, + "dump": true }, + "format": "tensorboard" "ops": ["min", "max", "norm", "zeros", "id"], "eps": 1e-8 } @@ -78,8 +181,9 @@ pip install . |"xy_distribution"| 可选 | 若为true则会监控指定module(targets中指定)的输入输出张量。 默认为false。| |"mv_distribution"| 可选 | 若为true则会监控指定模块中的参数的优化器状态, 默认为false。需要在TrainerMon构造函数正确指定opt_ty. 目前只支持megatron的混合精度优化器以及megatron的分布式优化器。 Deepspeed的分布式优化器实现暂不支持。 | |"wg_distribution"| 可选 | 若为true则会监控指定模块的参数梯度, 默认为false。 | -|"alert"| 必选 | · "rules": 指定自动报警的异常检测机制及其相应的阈值。目前实现的异常检测是AnomalyTurbulence。 如果统计标量超出历史均值的指定浮动范围(threshold指定, 0.5意味着上浮或者下浮50%)则在控制台打印报警信息。
· "inform": 自动报警需要的配置,若想关闭自动报警删掉inform的配置即可。其中"recipient"指定自动报警的通知方式,可选值为"database"或"email",默认为"database"。
- 若"recipient"为"database",则需要指定"connection_str"字段,即数据库的连接URL,默认为{"recipient":"database", "connection_str": "mysql+pymysql://username:password@host:port/database"},若有特殊字符需要转义。
- 若"recipient"为"email",则需要指定"send_email_address"-发送方邮箱地址,"receive_email_address"-接收方邮箱地址,"send_email_username"-发送方邮箱用户名,"send_email_password"-发送方邮箱密码,"smtp_server"-发送方邮箱对应的SMTP服务器,"smtp_port"-发送方邮箱对应的SMTP端口号。默认为:
{"recipient":"email", send_email_address": "sender@huawei.com", "receive_email_address": "receiver@huawei.com", "send_email_username": "username", "send_email_password": "******", "smtp_server": "smtpscn.huawei.com", "smtp_port": "587"}| -|"cc_distribution"| 可选 | 其中“enable”字段控制开关;需要监控通信算子时,务必尽量早地实例化`TrainerMon`, 因为监控通过劫持原始func后挂hook实现,部分加速库初始化时会保存原始function,避免监控失效。“cc_codeline”字段指定监控的代码行,如:`train.py\\[23\\]`,默认为空列表,不特别指定;"cc_pre_hook"字段控制是否监控通信前的数据; "cc_log_only"为true时,仅记录调用到的算子及其调用栈, 不监控通信的输入输出| +|"alert"| 必选 | · "rules": 指定自动报警的异常检测机制及其相应的阈值。目前实现的异常检测是AnomalyTurbulence。 如果统计标量超出历史均值的指定浮动范围(threshold指定, 0.5意味着上浮或者下浮50%)则在控制台打印报警信息。
· "inform": 自动报警需要的配置,若想关闭自动报警删掉inform的配置即可。其中"recipient"指定自动报警的通知方式,可选值为"database"或"email",默认为"database"。
- 若"recipient"为"database",则需要指定"connection_str"字段,即数据库的连接URL,默认为{"recipient":"database", "connection_str": "mysql+pymysql://username:password@host:port/database"},若有特殊字符需要转义。
- 若"recipient"为"email",则需要指定"send_email_address"-发送方邮箱地址,"receive_email_address"-接收方邮箱地址,"send_email_username"-发送方邮箱用户名,"send_email_password"-发送方邮箱密码,"smtp_server"-发送方邮箱对应的SMTP服务器,"smtp_port"-发送方邮箱对应的SMTP端口号。默认为:
{"recipient":"email", send_email_address": "sender@h*****.com", "receive_email_address": "receiver@h*****.com", "send_email_username": "username", "send_email_password": "******", "smtp_server": "smtpscn.h*****.com", "smtp_port": "587"}| +|"cc_distribution"| 可选 | 其中"enable"字段控制通信监控模块的开关;需要监控通信算子时,务必尽量早地实例化`TrainerMon`, 因为监控通过劫持原始func后挂hook实现,部分加速库初始化时会保存原始function,避免监控失效。"cc_codeline"字段指定监控的代码行,如:`train.py\\[23\\]`,默认为空列表,不特别指定;"cc_pre_hook"字段控制是否监控通信前的数据; 模块会在第二个optimize.step之前打印通信日志,包括通信api的调用栈、输入dtype、通信group。 "cc_log_only"为true时,仅打印日志,不监控通信的输入输出,并在打印后中断训练。可以根据通信日志设置"cc_codeline",规避与训练过程不相关的通信,比如一些时间、metrics的同步。| +|"format"| 可选 | 数据落盘格式,默认为tensorboard,支持可选 "csv"。 | |"ops"| 可选 |与ur_distribution、xy_distribution、mv_distribution、wg_distribution、mg_direction、cc_distribution配合,监控所选张量的min、max、norm、zeros值。其中,zeros代表监控所选张量的元素小于eps的比例,id代表监控所选的非张量本身,默认为[]。| |"eps"| 可选 |若ops里包含"zeros"则需要配置,默认为1e-8。| @@ -109,33 +213,36 @@ pip install . } ``` -2. 在训练器中加入代码,开启kj600训练监控。 +2. 在训练器中加入代码,开启monitor训练监控。 例如在ModelLink/pretrain_gpt.py的model_provider GPTModel构造后加入以下代码, **注意优化器类型opt_ty** : ``` - from kj600.module_hook import TrainerMon - hooker = TrainerMon("./llama2_config.json", params_have_main_grad=True, opt_ty="Megatron_DistributedOptimizer") # or opt_ty=Megatron_Float16OptimizerWithFloat16Params + from monitor.module_hook import TrainerMon + hooker = TrainerMon("./llama2_config.json", process_group=None, params_have_main_grad=True, opt_ty="Megatron_DistributedOptimizer") # or opt_ty=Megatron_Float16OptimizerWithFloat16Params hooker.hook_modules(model=model, grad_acc_steps=args.global_batch_size//args.data_parallel_size//args.micro_batch_size) ``` params_have_main_grad: 若为True则参数权重梯度为main_grad,否则为grad,默认为True。 如果不是Megatron-LM的训练框架, 可以设置对应的梯度累积步数grad_acc_steps。 - 如果要监控混合精度优化器的动量和方差, 需要在混合精度优化器构造后加入如下代码。 目前只支持Megatron_DistributedOptimizer, 使用bf16或者fp16混合精度时开启分布式优化器。 或者Megatron_Float16OptimizerWithFloat16Params, 使用bf16或者fp16混合精度选项并且不开启分布式优化器。 + 如果要监控优化器的动量和方差,需要在优化器构造后加入如下代码。 目前支持Megatron实现的优化器: + - Megatron_FP32OptimizerMon,普通优化器。 + - Megatron_Float16OptimizerWithFloat16Params, 使用bf16或者fp16混合精度选项并且不开启分布式优化器。 + - Megatron_DistributedOptimizer, 使用bf16或者fp16混合精度时开启分布式优化器。 ``` model, optimizer, opt_param_scheduler = setup_model_and_optimizer( model_provider, model_type) # 插入位置 - from kj600.module_hook import TrainerMon + from monitor.module_hook import TrainerMon TrainerMon.set_wrapped_optimizer(optimizer) ``` 3. 配置tensorboard写入的目录 ``` - export KJ600_OUTPUT_DIR=/xxx/output_dir + export MONITOR_OUTPUT_DIR=/xxx/output_dir ``` 4. 开始预训练,在日志中如果发现以下内容, 则说明指定的模块被成功监视。 @@ -148,7 +255,7 @@ pip install . 5. 训练过程中,打开tensorboard,可以查看训练的中间状态: ``` -tensorboard --logdir=$KJ600_OUTPUT_DIR +tensorboard --logdir=$MONITOR_OUTPUT_DIR ``` 之后,运行以下SSH命令来建立端口转发,可以在本地通过http://localhost:6006访问tensorboard: @@ -171,6 +278,7 @@ TrainerMon.__init__(config_file_path, params_have_main_grad=True, opt_ty=None) - | 参数 | 说明 | 是否必选 | | ----- | -------------------- | -------- | | config_file_path |自己写的json配置文件路径。 | 是 | +| process_group | 传入ProcessGroup对象,用以确定pipeline并行不同rank异常间时序,megatron下通过core.parallel_state.get_pipeline_model_parallel_group()获得 | 否 | | params_have_main_grad |权重是否使用main_grad,是就为True,否则为False。默认为True。 | 否 | | opt_ty |优化器类型,有两个选项,Megatron_DistributedOptimizer:使用bf16或者fp16混合精度时开启分布式优化器;Megatron_Float16OptimizerWithFloat16Params:使用bf16或者fp16混合精度选项并且不开启分布式优化器,也适用于常规的adam优化器。如果使用的不是adam优化器,使用None。默认为None。 | 否 | diff --git a/debug/accuracy_tools/kj600/img/cpu_info.png b/debug/accuracy_tools/monitor/img/cpu_info.png similarity index 100% rename from debug/accuracy_tools/kj600/img/cpu_info.png rename to debug/accuracy_tools/monitor/img/cpu_info.png diff --git a/debug/accuracy_tools/kj600/img/train.png b/debug/accuracy_tools/monitor/img/train.png similarity index 100% rename from debug/accuracy_tools/kj600/img/train.png rename to debug/accuracy_tools/monitor/img/train.png diff --git a/debug/accuracy_tools/kj600/img/train_with_kj600.png b/debug/accuracy_tools/monitor/img/train_with_kj600.png similarity index 100% rename from debug/accuracy_tools/kj600/img/train_with_kj600.png rename to debug/accuracy_tools/monitor/img/train_with_kj600.png diff --git a/debug/accuracy_tools/atat/mindspore/debugger/__init__.py b/debug/accuracy_tools/monitor/monitor/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/mindspore/debugger/__init__.py rename to debug/accuracy_tools/monitor/monitor/__init__.py diff --git a/debug/accuracy_tools/monitor/monitor/anomaly_analyse.py b/debug/accuracy_tools/monitor/monitor/anomaly_analyse.py new file mode 100644 index 0000000000000000000000000000000000000000..aad42a84cef30b735f4caecf9e6d9d7fb58394de --- /dev/null +++ b/debug/accuracy_tools/monitor/monitor/anomaly_analyse.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import argparse +import ast +import fcntl +import heapq +import json +import os +from pathlib import Path +import sys + +from monitor.utils import print_info_log, print_warn_log +from monitor.anomaly_detect import GradAnomalyData +from monitor.file_check import ( + change_mode, + check_link, + FileCheckConst, + check_path_before_create, + FileChecker, + FileOpen, +) + +ANOMALY_JSON = "anomaly.json" +ANALYSE_JSON = "anomaly_analyse.json" + +class AnomalyDataWriter: + """ + 异常数据写入类,负责将异常数据写入到JSON文件中。 + """ + + def __init__(self, dump_path, rank) -> None: + self.dump_path = dump_path + self.dump_rank_dir = os.path.join(self.dump_path, f"rank{rank}") + self.json_path = os.path.join(self.dump_rank_dir, ANOMALY_JSON) + + @staticmethod + def get_anomaly_dict(anomalies): + """将GradAnomalyData列表转换为json""" + anomalies_json = {} + for anomaly in anomalies: + anomalies_json.update({anomaly.get_key(): anomaly.to_dict()}) + return anomalies_json + + @staticmethod + def update_data_in_single_json(json_path, anomalies_data): + with FileOpen(json_path, "w+") as f: + fcntl.flock(f, fcntl.LOCK_EX) + json.dump(anomalies_data, f, indent=1) + fcntl.flock(f, fcntl.LOCK_UN) + change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY) + + def init_detected_json(self): + """初始化落盘文件""" + check_path_before_create(self.dump_path) + if not os.path.exists(self.dump_path): + Path(self.dump_path).mkdir( + mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True + ) + file_check = FileChecker(self.dump_path, FileCheckConst.DIR) + file_check.common_check() + + if not os.path.exists(self.dump_rank_dir): + Path(self.dump_rank_dir).mkdir( + FileCheckConst.DATA_DIR_AUTHORITY, parents=True, exist_ok=True + ) + + if os.path.exists(self.json_path): + file_check = FileChecker( + self.json_path, FileCheckConst.FILE, FileCheckConst.WRITE_ABLE + ) + file_check.common_check() + print_warn_log(f"The existing file will be deleted: {self.json_path}.") + os.remove(self.json_path) + Path(self.json_path).touch() + change_mode(self.json_path, FileCheckConst.DATA_FILE_AUTHORITY) + + def write_detected_json(self, anomalies): + """ + 落盘异常数据 + Args: + anomalies: GradAnomalyData对象列表 + """ + anomalies_json = self.get_anomaly_dict(anomalies) + print_info_log(f"{ANOMALY_JSON} is at {self.dump_rank_dir}.") + if Path(self.json_path).exists() and os.path.getsize(self.json_path) > 0: + with FileOpen(self.json_path, "r+") as f: + fcntl.flock(f, fcntl.LOCK_EX) + data_to_write = json.load(f) + fcntl.flock(f, fcntl.LOCK_UN) + else: + data_to_write = {} + data_to_write.update(anomalies_json) + self.update_data_in_single_json(self.json_path, data_to_write) + + +class AnomalyDataLoader: + def __init__(self, data_path) -> None: + self.data_path = data_path + + @staticmethod + def create_instances_from_dict(anomalies_dict: dict): + instances = [] + for values in anomalies_dict.values(): + try: + instances.append(GradAnomalyData(**values)) + except KeyError as e: + print_warn_log(f"Missing key in anomaly data: {e}") + except ValueError as e: + print_warn_log( + f"Value error when creating a GradAnomalyData instance: {e}" + ) + return instances + + def get_anomalies_from_jsons(self): + """遍历文件夹,从rankK/anomaly.json中读取异常数据 + return: anomalies: GradAnomalyData对象列表 + """ + anomalies = [] + check_link(self.data_path) + for rank_dir in os.listdir(self.data_path): + rank_path = os.path.join(self.data_path, rank_dir) + if not os.path.isdir(rank_path): + continue + json_path = os.path.join(rank_path, ANOMALY_JSON) + if not os.path.exists(json_path): + continue + with FileOpen(json_path, "r+") as f: + fcntl.flock(f, fcntl.LOCK_EX) + data_anomalies = json.load(f) + fcntl.flock(f, fcntl.LOCK_UN) + instances = self.create_instances_from_dict(data_anomalies) + anomalies.extend(instances) + return anomalies + + +class AnomalyAnalyse: + def __init__(self) -> None: + self.sorted_anomalies = [] + + def get_range_top_K(self, topk, step_list, anomalies): + """ + 获取前topk个step_list范围内的异常。 + """ + if not step_list: + filtered_anomalies = anomalies + else: + filtered_anomalies = [ + anomaly for anomaly in anomalies if anomaly.step in step_list + ] + if topk >= len(filtered_anomalies): + self.sorted_anomalies = sorted(filtered_anomalies) + else: + self.sorted_anomalies = list(heapq.nsmallest(topk, filtered_anomalies)) + return self.sorted_anomalies + + def rewrite_sorted_anomalies(self, output_path): + """ + 将排序后的异常数据重新落盘 + """ + file_check = FileChecker( + output_path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE + ) + file_check.common_check() + + sorted_data = AnomalyDataWriter.get_anomaly_dict(self.sorted_anomalies) + print_info_log(f"{ANALYSE_JSON} is at {output_path}.") + json_path = os.path.join(output_path, ANALYSE_JSON) + if os.path.exists(json_path): + file_check = FileChecker( + json_path, FileCheckConst.FILE, FileCheckConst.WRITE_ABLE + ) + file_check.common_check() + print_warn_log(f"The existing file will be deleted: {json_path}.") + os.remove(json_path) + Path(json_path).touch() + change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY) + AnomalyDataWriter.update_data_in_single_json(json_path, sorted_data) + + +def _get_parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--data_path", dest="data_path_dir", default="./", type=str, + help=" The anomaly detect result dictionary: generate from monitor tool.", + required=True, + ) + parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, + help=" The analyse task result out path.", + required=False, + ) + parser.add_argument("-k", "--topk", dest="top_k_number", default=8, type=int, + help=" Top K number of earliest anomalies.", + required=False, + ) + parser.add_argument("-s", "--step", dest="step_list", default=[], type=str, + help=" Analyse which steps.", + required=False, + ) + return parser.parse_args(sys.argv[1:]) + +def _get_step_and_stop(args): + try: + step_list = ast.literal_eval(args.step_list) + if not isinstance(step_list, list): + raise ValueError(f"{args.step_list} is not a list") + except (ValueError, SyntaxError, RecursionError) as e: + raise Exception( + f"The step list must be a resolvable list type" + ) from e + if args.top_k_number <= 0: + raise Exception("The top k number must be greater than 0.") + return step_list, args.top_k_number + +def _anomaly_analyse(): + args = _get_parse_args() + step_list, top_k_number = _get_step_and_stop(args) + loader = AnomalyDataLoader(args.data_path_dir) + anomalies = loader.get_anomalies_from_jsons() + analyser = AnomalyAnalyse() + top_anomalies = analyser.get_range_top_K( + top_k_number, step_list, anomalies + ) + analyser.rewrite_sorted_anomalies( + args.out_path if args.out_path else args.data_path_dir + ) + + print_info_log(f"Top {top_k_number} anomalies are listed as follows:") + for index, anomaly in enumerate(top_anomalies): + print_info_log(f"{index}: {anomaly.message}") + + +if __name__ == "__main__": + _anomaly_analyse() + print_info_log("Analyse task completed.") diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py b/debug/accuracy_tools/monitor/monitor/anomaly_detect.py similarity index 32% rename from debug/accuracy_tools/kj600/kj600/anomaly_detect.py rename to debug/accuracy_tools/monitor/monitor/anomaly_detect.py index cbd7b6daa2f0d9b0a9b28016993e836ee07df72d..4a5d9baf9c578324c2e93cedb44ed66551fa813f 100644 --- a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py +++ b/debug/accuracy_tools/monitor/monitor/anomaly_detect.py @@ -1,10 +1,16 @@ +import os +import sys import statistics as st from abc import ABC from typing import List -import sys -from torch.utils.tensorboard import SummaryWriter from collections import defaultdict -from kj600.utils import print_info_log +from dataclasses import dataclass, field +import pandas as pd +from torch.utils.tensorboard import SummaryWriter +from monitor.utils import print_info_log, check_file_valid_writable, make_file_safety, create_directory +from monitor.const import Const +from monitor.file_check import change_mode, FileCheckConst + class ScanRule(ABC): def apply(self, history, cur): @@ -59,15 +65,101 @@ class bcolors: BOLD = '\033[1m' UNDERLINE = '\033[4m' -class SummaryWriterWithAD(SummaryWriter): - def __init__(self, path, ad_rules, job_id, anomaly_inform=False): - super().__init__(path) +class AnomalyDataFactory(ABC): + def __init__(self, rank, pp_stage, group_mates): + super().__init__() + self.rank = rank + self.pp_stage = pp_stage + self.group_mates = group_mates + self.micro_step = 0 + self.vpp_stage = 0 + self.name2callid = {} + + def set_call_id(self, name2callid): + """根据当前GradContext信息更新call_id vpp_stage等信息 + """ + self.name2callid = name2callid + + def create(self, tag_name, message, step): + """如果检查出异常, 调用当前接口生成GradAnomalyData实例 + """ + param_name = tag_name.split('/')[0] + call_id = self.name2callid.get(param_name,-1) + if Const.vpp in param_name: + vpp_stage = int(param_name.lstrip(Const.vpp).split(Const.vpp_sep)[0]) + else: + vpp_stage = 0 + + return GradAnomalyData( + self.rank, + step, + self.micro_step, + self.pp_stage, + self.vpp_stage, + call_id, + tag_name, + message, + self.group_mates + ) + +@dataclass(eq=True) +class GradAnomalyData: + rank: int = 0 + step: int = 0 + micro_step: int = 0 + pp_stage: int = 0 + vpp_stage: int = 0 + call_id: int = 0 + tag_name: str = field(default=None, compare=False) + message: str = field(default="", compare=False) + group_mates: list = field(default=None, compare=False) + + def __lt__(self, other): + if not isinstance(other, GradAnomalyData): + return NotImplemented + if self.step != other.step: + return self.step < other.step + if self.micro_step != other.micro_step: + return self.micro_step < other.micro_step + if self.pp_stage != other.pp_stage: + return self.pp_stage > other.pp_stage + if self.vpp_stage != other.vpp_stage: + return self.vpp_stage > other.vpp_stage + if self.call_id != other.call_id: + return self.call_id < other.call_id + return False + + def __le__(self, other): + if not isinstance(other, GradAnomalyData): + return NotImplemented + return self == other or self < other + + def to_dict(self): + return self.__dict__ + + def get_key(self): + return ''.join( + (str(self.tag_name), "_step_", str(self.step), "_call_" , str(self.call_id))) + +class BaseWriterWithAD: + def __init__(self, path, ad_rules, job_id, anomaly_inform=False, anomaly_factory=None, ndigits=6): self.tag2scalars = defaultdict(list) self.ad_rules = ad_rules self.job_id = job_id self.anomaly_inform = anomaly_inform - - def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False): + self.anomaly_factory = anomaly_factory + self.anomalies = [] + self.ndigits = ndigits + + def get_anomalies(self): + """返回已检测到的异常列表 + """ + return self.anomalies + + def clear_anomalies(self): + self.anomalies.clear() + + def add_scalar(self, tag, scalar_value, global_step=None): new_avg = avg = scalar_value if tag in self.tag2scalars: N = len(self.tag2scalars[tag]) @@ -76,11 +168,64 @@ class SummaryWriterWithAD(SummaryWriter): self.tag2scalars[tag].append((scalar_value, new_avg)) detected, rule_name = self._ad(scalar_value, history=avg) if detected: - print_info_log(f"{bcolors.WARNING}> Rule {rule_name} reports anomaly signal in {tag} at step {global_step}.{bcolors.ENDC}") - exception_message = f"{bcolors.WARNING}> Rule {rule_name} reports anomaly signal in {tag} at step {global_step}.{bcolors.ENDC}" + exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." + print_info_log(f"{bcolors.WARNING}> {exception_message}{bcolors.ENDC}") if self.anomaly_inform: self.anomaly_inform.run(exception_message, self.job_id) - return super().add_scalar(tag, scalar_value, global_step, walltime, new_style, double_precision) - + + if self.anomaly_factory: + self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step)) + def _ad(self, scalar_value, history): return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value) + + +class CSVWriterWithAD(BaseWriterWithAD): + def __init__(self, path, ad_rules, job_id, anomaly_inform=False, anomaly_factory=None, ndigits=6): + super().__init__(path, ad_rules, job_id, anomaly_inform, anomaly_factory, ndigits) + + self.log_dir = path + create_directory(path) + self.context_dict = defaultdict(list) + self.header = [] + + def write_csv(self, prefix, step): + if len(self.context_dict) == 0: + return + filepath = os.path.join(self.log_dir, f'{prefix}_{step}.csv') + if not os.path.exists(filepath): + make_file_safety(filepath) + data_frame = pd.DataFrame(columns=self.header) + data_frame.to_csv(filepath, index=False) + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + check_file_valid_writable(filepath) + new_data = [] + for name, metric_value in self.context_dict.items(): + if Const.vpp not in name: + new_data.append([name]+metric_value) + else: + new_data.append(name.lstrip(Const.vpp).split(Const.vpp_sep)+metric_value) + new_data = pd.DataFrame(new_data) + new_data.to_csv(filepath, mode='a+', header=False, index=False) + self.context_dict = defaultdict(list) + + def add_scalar(self, tag, scalar_value, global_step): + super().add_scalar(tag, scalar_value, global_step) + + name = tag.split('/')[0] + self.context_dict[name].append(round(scalar_value, self.ndigits)) + + def close(self): + pass + +class SummaryWriterWithAD(SummaryWriter, BaseWriterWithAD): + def __init__(self, path, ad_rules, job_id, anomaly_inform=False, anomaly_factory=None, ndigits=6): + super(SummaryWriter, self).__init__(path, ad_rules, job_id, anomaly_inform, anomaly_factory, ndigits) + super().__init__(path) + change_mode(path, FileCheckConst.DATA_DIR_AUTHORITY) + + def add_scalar(self, tag, scalar_value, global_step): + super(SummaryWriter, self).add_scalar(tag, scalar_value, global_step) + return super().add_scalar(tag, scalar_value, global_step) + \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_inform.py b/debug/accuracy_tools/monitor/monitor/anomaly_inform.py similarity index 95% rename from debug/accuracy_tools/kj600/kj600/anomaly_inform.py rename to debug/accuracy_tools/monitor/monitor/anomaly_inform.py index 301ac769217943a36e5d4cbe06033c828e5c675e..fe0fdb3f105cf3f400eee91734e37048fbaa3ca0 100644 --- a/debug/accuracy_tools/kj600/kj600/anomaly_inform.py +++ b/debug/accuracy_tools/monitor/monitor/anomaly_inform.py @@ -1,18 +1,17 @@ import smtplib from email.mime.text import MIMEText -import sqlite3 from datetime import datetime, timedelta -from kj600.database import Database, ExceptionMessage +from monitor.database import Database, ExceptionMessage # define class InformRegistry to get inform_sub_class class AnomalyInformFactory: @staticmethod def create_informer(**kwargs): - if kwargs['recipient'] == "database": + if kwargs.get('recipient') == "database": return DatabaseInform(**kwargs) - elif kwargs['recipient'] == "email": + elif kwargs.get('recipient') == "email": return EmailInform(**kwargs) else: raise ValueError("Invaild recipient specified") diff --git a/debug/accuracy_tools/kj600/kj600/config.json b/debug/accuracy_tools/monitor/monitor/config.json similarity index 100% rename from debug/accuracy_tools/kj600/kj600/config.json rename to debug/accuracy_tools/monitor/monitor/config.json diff --git a/debug/accuracy_tools/monitor/monitor/const.py b/debug/accuracy_tools/monitor/monitor/const.py new file mode 100644 index 0000000000000000000000000000000000000000..e4198a994228fb4e4a99ec8a3aaa5fe9b3d14dcd --- /dev/null +++ b/debug/accuracy_tools/monitor/monitor/const.py @@ -0,0 +1,4 @@ + +class Const: + vpp = "vpp" + vpp_sep = ':' \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/database.py b/debug/accuracy_tools/monitor/monitor/database.py similarity index 100% rename from debug/accuracy_tools/kj600/kj600/database.py rename to debug/accuracy_tools/monitor/monitor/database.py diff --git a/debug/accuracy_tools/kj600/kj600/distributed/distributed_ops.yaml b/debug/accuracy_tools/monitor/monitor/distributed/distributed_ops.yaml similarity index 100% rename from debug/accuracy_tools/kj600/kj600/distributed/distributed_ops.yaml rename to debug/accuracy_tools/monitor/monitor/distributed/distributed_ops.yaml diff --git a/debug/accuracy_tools/kj600/kj600/distributed/stack_blacklist.yaml b/debug/accuracy_tools/monitor/monitor/distributed/stack_blacklist.yaml similarity index 77% rename from debug/accuracy_tools/kj600/kj600/distributed/stack_blacklist.yaml rename to debug/accuracy_tools/monitor/monitor/distributed/stack_blacklist.yaml index 00b0013619fcfa1445a8df18c3c7d16764fb4872..40692d6942c39dfd5bbe52d33df6de4f1238eab2 100644 --- a/debug/accuracy_tools/kj600/kj600/distributed/stack_blacklist.yaml +++ b/debug/accuracy_tools/monitor/monitor/distributed/stack_blacklist.yaml @@ -1,5 +1,5 @@ stack: -- kj600/distributed +- monitor/distributed - site-packages/torch/nn/modules/module.py - multiprocessing - debugpy \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py b/debug/accuracy_tools/monitor/monitor/distributed/wrap_distributed.py similarity index 96% rename from debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py rename to debug/accuracy_tools/monitor/monitor/distributed/wrap_distributed.py index fad007fe35c45b1eb09b8ed986bd9b9893528a75..3f5ebc727dfd306d127ac7605a8aff6e83182f54 100644 --- a/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py +++ b/debug/accuracy_tools/monitor/monitor/distributed/wrap_distributed.py @@ -1,7 +1,8 @@ import os -import yaml import re import inspect + +import yaml import torch import torch.nn as nn import torch.distributed as dist @@ -77,6 +78,21 @@ class ApiRegistry: else: setattr(api_group, cc_api_name, cc_api_entry_func) + @staticmethod + def redirect_wait(): + global ORIGIN_WAIT + global PENDING_ASYNC_CC_BY_HANDLE + + def wrapped_wait(work): + def wrapped_wait(*args, **kwargs): + ORIGIN_WAIT(*args, **kwargs) + if args[0] in PENDING_ASYNC_CC_BY_HANDLE: + store_func = PENDING_ASYNC_CC_BY_HANDLE.pop(args[0]) + store_func() + return wrapped_wait + + dist.Work.wait = wrapped_wait(dist.Work) + def redirect_api(self): self.set_api_attr(dist, self.distributed_attr_hooked) self.set_api_attr(dist.distributed_c10d, self.distributed_attr_hooked) @@ -92,19 +108,13 @@ class ApiRegistry: for op_name in get_distributed_ops(): self.distributed_attr_hooked[op_name] = DistributedOPTemplate(op_name, pre_hooks, post_hooks) - def redirect_wait(self): - global ORIGIN_WAIT - global PENDING_ASYNC_CC_BY_HANDLE - - def wrapped_wait(work): - def wrapped_wait(*args, **kwargs): - ORIGIN_WAIT(*args, **kwargs) - if args[0] in PENDING_ASYNC_CC_BY_HANDLE: - store_func = PENDING_ASYNC_CC_BY_HANDLE.pop(args[0]) - store_func() - return wrapped_wait - dist.Work.wait = wrapped_wait(dist.Work) +def get_process_group(process_group): + return ( + process_group + if isinstance(process_group, dist.ProcessGroup) + else dist.GroupMember.WORLD + ) def stack_filter(stack): @@ -115,7 +125,7 @@ def stack_filter(stack): def get_callstack(): callstack = [] - for (_, path, line, func, code, _) in inspect.stack(): + for (_, path, line, func, _, _) in inspect.stack(): stack_line = f'{path}[{line}]' if stack_filter(stack_line): callstack.append(stack_line+' '+func) diff --git a/debug/accuracy_tools/kj600/kj600/features.py b/debug/accuracy_tools/monitor/monitor/features.py similarity index 94% rename from debug/accuracy_tools/kj600/kj600/features.py rename to debug/accuracy_tools/monitor/monitor/features.py index 7810188f7d7df66dce4c489f18062f9381b95646..92c6959a7d56ed3264366bf12444029f00b4df25 100644 --- a/debug/accuracy_tools/kj600/kj600/features.py +++ b/debug/accuracy_tools/monitor/monitor/features.py @@ -1,6 +1,6 @@ import torch from torch.autograd.functional import jacobian -from kj600.utils import print_info_log +from monitor.utils import print_info_log @torch.no_grad() @@ -11,6 +11,10 @@ def square_sum(x: torch.tensor): def get_min(x: torch.tensor): return torch.min(x) +@torch.no_grad() +def get_mean(x: torch.tensor): + return torch.mean(x) + @torch.no_grad() def get_norm(x: torch.tensor): return torch.norm(x, p=2) diff --git a/debug/accuracy_tools/atat/core/common/file_check.py b/debug/accuracy_tools/monitor/monitor/file_check.py similarity index 65% rename from debug/accuracy_tools/atat/core/common/file_check.py rename to debug/accuracy_tools/monitor/monitor/file_check.py index 2df825aa35108fc08b9d886bf68f0ef3e2bc1533..af233ca2c8f01b76d912b753078a5cbf117049c6 100644 --- a/debug/accuracy_tools/atat/core/common/file_check.py +++ b/debug/accuracy_tools/monitor/monitor/file_check.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,9 +17,58 @@ import os import re -from atat.core.common.log import logger -from atat.core.common.exceptions import FileCheckException -from atat.core.common.const import FileCheckConst +from monitor.utils import print_info_log + + +class CodedException(Exception): + def __init__(self, code, error_info=""): + super().__init__() + self.code = code + self.error_info = self.err_strs.get(code) + error_info + + def __str__(self): + return self.error_info + + +class FileCheckException(CodedException): + INVALID_FILE_ERROR = 0 + FILE_PERMISSION_ERROR = 1 + SOFT_LINK_ERROR = 2 + ILLEGAL_PATH_ERROR = 3 + ILLEGAL_PARAM_ERROR = 4 + FILE_TOO_LARGE_ERROR = 5 + + err_strs = { + SOFT_LINK_ERROR: "[monitor] 检测到软链接: ", + FILE_PERMISSION_ERROR: "[monitor] 文件权限错误: ", + INVALID_FILE_ERROR: "[monitor] 无效文件: ", + ILLEGAL_PATH_ERROR: "[monitor] 非法文件路径: ", + ILLEGAL_PARAM_ERROR: "[monitor] 非法打开方式: ", + FILE_TOO_LARGE_ERROR: "[monitor] 文件过大: ", + } + + +class FileCheckConst: + """ + Class for file check const + """ + + READ_ABLE = "read" + WRITE_ABLE = "write" + READ_WRITE_ABLE = "read and write" + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" + FILE_PATTERN = r"^[a-zA-Z0-9_./-]+$" + JSON_SUFFIX = ".json" + MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + DIR = "dir" + FILE = "file" + DATA_DIR_AUTHORITY = 0o750 + DATA_FILE_AUTHORITY = 0o640 + FILE_SIZE_DICT = { + JSON_SUFFIX: MAX_JSON_SIZE, + } class FileChecker: @@ -32,7 +81,10 @@ class FileChecker: ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability file_type(str): The correct file type for file """ - def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): + + def __init__( + self, file_path, path_type, ability=None, file_type=None, is_script=True + ): self.file_path = file_path self.path_type = self._check_path_type(path_type) self.ability = ability @@ -42,7 +94,9 @@ class FileChecker: @staticmethod def _check_path_type(path_type): if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: - logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') + print_info_log( + f"The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}." + ) raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) return path_type @@ -82,11 +136,12 @@ class FileOpen: file_path: The file or dictionary path to be opened. mode(str): The file open mode """ + SUPPORT_READ_MODE = ["r", "rb"] SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"] SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"] - def __init__(self, file_path, mode, encoding='utf-8'): + def __init__(self, file_path, mode, encoding="utf-8"): self.file_path = file_path self.mode = mode self.encoding = encoding @@ -106,9 +161,13 @@ class FileOpen: self._handle.close() def check_file_path(self): - support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE + support_mode = ( + self.SUPPORT_READ_MODE + + self.SUPPORT_WRITE_MODE + + self.SUPPORT_READ_WRITE_MODE + ) if self.mode not in support_mode: - logger.error("File open not support %s mode" % self.mode) + print_info_log("File open not support %s mode" % self.mode) raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) check_link(self.file_path) self.file_path = os.path.realpath(self.file_path) @@ -135,66 +194,75 @@ class FileOpen: def check_link(path): abs_path = os.path.abspath(path) if os.path.islink(abs_path): - logger.error('The file path {} is a soft link.'.format(path)) + print_info_log("The file path {} is a soft link.".format(path)) raise FileCheckException(FileCheckException.SOFT_LINK_ERROR) def check_path_length(path, name_length=None): - file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH - if len(path) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(path)) > file_max_name_length: - logger.error('The file path length exceeds limit.') + file_max_name_length = ( + name_length if name_length else FileCheckConst.FILE_NAME_LENGTH + ) + if ( + len(path) > FileCheckConst.DIRECTORY_LENGTH + or len(os.path.basename(path)) > file_max_name_length + ): + print_info_log("The file path length exceeds limit.") raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_exists(path): if not os.path.exists(path): - logger.error('The file path %s does not exist.' % path) + print_info_log("The file path %s does not exist." % path) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_readability(path): if not os.access(path, os.R_OK): - logger.error('The file path %s is not readable.' % path) + print_info_log("The file path %s is not readable." % path) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_writability(path): if not os.access(path, os.W_OK): - logger.error('The file path %s is not writable.' % path) + print_info_log("The file path %s is not writable." % path) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_executable(path): if not os.access(path, os.X_OK): - logger.error('The file path %s is not executable.' % path) + print_info_log("The file path %s is not executable." % path) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_other_user_writable(path): st = os.stat(path) if st.st_mode & 0o002: - logger.error('The file path %s may be insecure because other users have write permissions. ' % path) + print_info_log( + "The file path %s may be insecure because other users have write permissions. " + % path + ) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_owner_consistent(path): file_owner = os.stat(path).st_uid if file_owner != os.getuid(): - logger.error('The file path %s may be insecure because is does not belong to you.' % path) + print_info_log( + "The file path %s may be insecure because is does not belong to you." % path + ) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_pattern_vaild(path): if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): - logger.error('The file path %s contains special characters.' %(path)) + print_info_log("The file path %s contains special characters." % (path)) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_file_size(file_path, max_size): file_size = os.path.getsize(file_path) if file_size >= max_size: - logger.error(f'The size of file path {file_path} exceeds {max_size} bytes.') + print_info_log(f"The size of file path {file_path} exceeds {max_size} bytes.") raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR) @@ -209,45 +277,32 @@ def check_common_file_size(file_path): def check_file_suffix(file_path, file_suffix): if file_suffix: if not file_path.endswith(file_suffix): - logger.error(f"The {file_path} should be a {file_suffix} file!") + print_info_log(f"The {file_path} should be a {file_suffix} file!") raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) def check_path_type(file_path, file_type): if file_type == FileCheckConst.FILE: if not os.path.isfile(file_path): - logger.error(f"The {file_path} should be a file!") + print_info_log(f"The {file_path} should be a file!") raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) if file_type == FileCheckConst.DIR: if not os.path.isdir(file_path): - logger.error(f"The {file_path} should be a dictionary!") + print_info_log(f"The {file_path} should be a dictionary!") raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) - - -def create_directory(dir_path): - """ - Function Description: - creating a directory with specified permissions - Parameter: - dir_path: directory path - Exception Description: - when invalid data throw exception - """ - dir_path = os.path.realpath(dir_path) - try: - os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) - except OSError as ex: - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, - 'Failed to create {}. Please check the path permission or disk space .{}'.format(dir_path, str(ex))) from ex - + def check_path_before_create(path): if path_len_exceeds_limit(path): - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.') + raise FileCheckException( + FileCheckException.ILLEGAL_PATH_ERROR, "The file path length exceeds limit." + ) if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)): - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, - 'The file path {} contains special characters.'.format(path)) + raise FileCheckException( + FileCheckException.ILLEGAL_PATH_ERROR, + "The file path {} contains special characters.".format(path), + ) def change_mode(path, mode): @@ -256,10 +311,14 @@ def change_mode(path, mode): try: os.chmod(path, mode) except PermissionError as ex: - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR, - 'Failed to change {} authority. {}'.format(path, str(ex))) from ex + raise FileCheckException( + FileCheckException.FILE_PERMISSION_ERROR, + "Failed to change {} authority. {}".format(path, str(ex)), + ) from ex def path_len_exceeds_limit(file_path): - return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH \ No newline at end of file + return ( + len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH + or len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH + ) diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/monitor/monitor/module_hook.py similarity index 49% rename from debug/accuracy_tools/kj600/kj600/module_hook.py rename to debug/accuracy_tools/monitor/monitor/module_hook.py index 8043c5671c42675dc55944e4818bce8e9137b455..dca629cb161d4b2eb4e6e330d8dfe3d76c76775e 100644 --- a/debug/accuracy_tools/kj600/kj600/module_hook.py +++ b/debug/accuracy_tools/monitor/monitor/module_hook.py @@ -2,19 +2,46 @@ import os import uuid import json from collections import defaultdict +from functools import partial from datetime import datetime import torch + +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' +if not torch_version_above_or_equal_2: + raise ValueError("msmonitor require torch>=2.0") + import torch.distributed as dist +from torch import Stream from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook -from kj600.module_spec_verifier import get_config, validate_config_spec -from kj600.optimizer_collect import MixPrecsionOptimizerMon, print_rank_0, OptimizerMonFactory, MegatronDistributedOptimizerMon -from kj600.features import eff_rank, get_sign_matches -from kj600.visualizer import HeatmapVisualizer -from kj600.anomaly_detect import AnomalyScanner, SummaryWriterWithAD -from kj600.anomaly_inform import AnomalyInformFactory -from kj600.module_metric import get_metrics, write_metrics_tensorboard, get_summary_writer_tag_name, TensorMetrics -from kj600.distributed.wrap_distributed import api_register, create_hooks, op_aggregate -from kj600.utils import print_warn_log, print_info_log, get_param_struct +from monitor.module_spec_verifier import validate_config_spec +from monitor.optimizer_collect import OptimizerMon, print_rank_0, OptimizerMonFactory +from monitor.features import eff_rank, get_sign_matches +from monitor.visualizer import HeatmapVisualizer +from monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, SummaryWriterWithAD, CSVWriterWithAD, \ + BaseWriterWithAD +from monitor.anomaly_inform import AnomalyInformFactory +from monitor.anomaly_analyse import AnomalyDataWriter +from monitor.module_metric import get_metrics, write_metrics_tensorboard, write_metrics_csv, \ + get_summary_writer_tag_name, TensorMetrics, squash_param_name +from monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, get_process_group +from monitor.utils import print_warn_log, print_info_log, print_error_log, get_param_struct +from monitor.const import Const +from monitor.file_check import FileOpen + +try: + import torch_npu +except ImportError: + pass + + +def param_is_not_tensor_parallel_duplicate(param, tp_group): + return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or ( + torch.distributed.get_rank(group=tp_group) == 0 + ) + + +def param_is_data_parallel_duplicate(dp_group): + return torch.distributed.get_rank(group=dp_group) != 0 class ModuleHookContext: @@ -30,20 +57,18 @@ class ModuleHookContext: self.focused_out_col = 0 self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found - def set_format_by_arg(self, key_name:str, target_config:dict): + def set_format_by_arg(self, key_name: str, target_config: dict): if key_name in target_config[self.module_name]: self.format_by_arg[key_name] = target_config[self.module_name][key_name] elif key_name in ['input', 'input_grad']: self.ignore_in = True - else: - raise KeyError(f"Missing key: {key_name} of {self.module_name} in config.json") class OptimizerContext: def __init__(self) -> None: self.step = 0 self.param_effective_rank = defaultdict(float) - self.param_mg_direction = defaultdict(float) + self.param_mg_direction = defaultdict(float) self.param_adam_update = defaultdict() self.param_adam_ratio = defaultdict() self.param_weight_grad = defaultdict() @@ -71,30 +96,50 @@ class CommunicationContext: def aggregate(self): self.data = self._agg(self.data) -class TrainerMon: +class GradContext: + def __init__(self) -> None: + self.pre = [] + self.post = [] + self.acc_metric = [] + self.acc = {} + self.actv = defaultdict(dict) + + def reset(self): + self.pre.clear() + self.post.clear() + self.acc_metric.clear() + self.acc.clear() + self.actv.clear() + + +class TrainerMon: tensor_metrics = TensorMetrics() - # opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer" - def __init__(self, config_file_path, params_have_main_grad=True, opt_ty=None) -> None: + def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None: self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext) self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext) self.optimizer_context = defaultdict(OptimizerContext) self.cc_context = defaultdict(CommunicationContext) + self.grad_context = GradContext() + self.process_group = get_process_group(process_group) self.params_have_main_grad = params_have_main_grad - self.config = get_config(config_file_path) + with FileOpen(config_file_path, 'r') as f: + self.config = json.load(f) self.module_rank_list = self.config.get("module_ranks", []) + self.format = self.config.get('format', 'tensorboard') self.eps = self.config.get('eps', 1e-8) self.ops = self.config.get('ops', []) + self.ndigits = self.config.get('ndigits', 6) self.xy_distribution = self.config.get('xy_distribution', False) if not self.xy_distribution: print_rank_0("> module input/output input_grad/output_grad is not monitored. ") - - # backward hook cause megatron-lm pipeline parallel schedule assert exception. + # backward hook cause megatron-lm pipeline parallel schedule assert exception. # TBD: backward hook cause output tensor is view of some base tensor. root cause invesigation pending. - self.forward_only = self.config.get('forward_only', False) - if self.forward_only: + self.forward_only = self.config.get('forward_only', False) + if self.forward_only: print_rank_0("> only module forward is monitored. ") + self.backward_only = self.config.get('backward_only', False) self.ur_distribution = self.config.get('ur_distribution', False) if not self.ur_distribution: @@ -120,28 +165,72 @@ class TrainerMon: api_register.initialize_hook(*create_hooks(context=self.cc_context, monitor=self)) api_register.redirect_api() - alert_setting = self.config.get('alert', {"rules":[]}) + alert_setting = self.config.get('alert', {"rules": []}) self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"]) - - anomaly_inform = AnomalyInformFactory.create_informer(**alert_setting["inform"]) if "inform" in alert_setting else None - - self.optimizer_hooked = False - output_base_dir = os.getenv('KJ600_OUTPUT_DIR', './kj600_output') + anomaly_inform = AnomalyInformFactory.create_informer( + **alert_setting["inform"]) if "inform" in alert_setting else None + + output_base_dir = os.getenv('MONITOR_OUTPUT_DIR', './monitor_output') cur_time = datetime.now().strftime('%b%d_%H-%M-%S') unique_id = str(uuid.uuid4())[:8] + if dist.is_initialized(): - if (dist.get_rank() in self.module_rank_list) or len(self.module_rank_list) == 0: - self.summary_writer = SummaryWriterWithAD( - os.path.join(output_base_dir, f"{cur_time}-rank{dist.get_rank()}-{unique_id}"), self.alert_rules, unique_id, anomaly_inform) + rank = dist.get_rank() + tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}") + pp_stage = dist.get_group_rank(self.process_group, rank) + group_mates = dist.get_process_group_ranks(self.process_group) else: - self.summary_writer = SummaryWriterWithAD(os.path.join(output_base_dir, f"{cur_time}-{unique_id}"), self.alert_rules, unique_id, anomaly_inform) + rank = 0 + tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}") + pp_stage = 0 + group_mates = [0] + self.rank = rank + + # 初始化AnomalyData工厂 + self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates) if alert_setting.get('dump', + False) else None + + if self.format == 'tensorboard': + writer = SummaryWriterWithAD + self.write_metrics = write_metrics_tensorboard + elif self.format == 'csv': + writer = CSVWriterWithAD + self.write_metrics = write_metrics_csv + elif self.format == 'api': + writer = BaseWriterWithAD + self.write_metrics = write_metrics_tensorboard + + if (rank in self.module_rank_list) or len(self.module_rank_list) == 0: + + self.summary_writer = writer( + tensorboard_dir, + self.alert_rules, + unique_id, + anomaly_inform, + self.anomaly_data_factory, + self.ndigits + ) + # 初始化anomaly deteted文件目录 + if self.anomaly_data_factory: + self.anomaly_data_writer = AnomalyDataWriter( + os.path.join(output_base_dir, "anomaly_detected"), rank) + self.anomaly_data_writer.init_detected_json() + # A HeatmapVisualizer instance is associated with an image self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer) self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer) - self.micro_batch_number = 0 + self.micro_batch_number = 1 + + self.weight_hooked = False + self.optimizer_hooked = False + self.param_registered = False + self.vpp = False + self.dp_group = None + self.tp_group = None - self.param_name_list = [] self.param2name = defaultdict(str) + self.param_name_call_id = {} + self.call_id = 0 self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty) if opt_ty is None: @@ -149,21 +238,25 @@ class TrainerMon: raise Exception("ur_distribution cannot be enabled with unknown optimizer.") if self.mv_distribution: raise Exception("mv_distribution cannot be enabled with unknown optimizer.") + self.verbose = False self.print_struct = self.config.get("print_struct", False) + if self.print_struct: + self.verbose = True self.struct_printed = False - self.module_struct = {} + self.module_struct = defaultdict(dict) + return def __del__(self): if hasattr(self, "summary_writer"): self.summary_writer.close() - + @staticmethod def set_wrapped_optimizer(_wrapped_optimizer): - MixPrecsionOptimizerMon.set_wrapped_optimizer(_wrapped_optimizer) + OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer) @staticmethod - def adhoc_check(target_tensor:torch.tensor, module_name:str, tensor_name:str, rank_list, ops_list): + def adhoc_check(target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list): rank = None if dist.is_initialized(): rank = dist.get_rank() @@ -171,60 +264,97 @@ class TrainerMon: return TrainerMon.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank) - def hook_modules(self, model:torch.nn.Module, grad_acc_steps): - # fwd=0, bkd=1 - # targets is module name list like ["xx.xxx1", "xxx.xxx2"] which can be obtained when first run. - print_rank_0("> module names:") - for name, _ in model.named_modules(): - print_rank_0(f"\t{name}") - self.micro_batch_number = grad_acc_steps + @staticmethod + def build_tbtag_tensor_map(module_name, tag, tensor): + metrics = {} + rank = dist.get_rank() if dist.is_initialized() else None + key = get_summary_writer_tag_name(module_name, tag, rank) + if tensor is not None: + metrics[key] = tensor + return metrics - if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list): - targets = [x for x, _ in model.named_modules()] if self.print_struct else self.config['targets'].keys() - hooked_count = self._hook_module(targets, model, fwd_or_bkd=0) - print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.") - else: + @staticmethod + def generate_cc_metrics(cc_name, cc_tensor): + metrics = defaultdict(dict) + rank = dist.get_rank() if dist.is_initialized() else None + for op, tag2tensor in cc_tensor.data.items(): + for tag, tensor in tag2tensor.items(): + key = get_summary_writer_tag_name(cc_name, tag, rank) + metrics[op].update({key: tensor}) + cc_tensor.reset() + return metrics + + def hook_modules(self, model: torch.nn.Module, grad_acc_steps): + if self.module_rank_list and (self.rank not in self.module_rank_list): return + if not isinstance(model, list): + model = [model] + + self._register_param_name(model) + + self.micro_batch_number = grad_acc_steps + for vpp_stage, model_chunk in enumerate(model): + vpp_stage = f'{vpp_stage}{Const.vpp_sep}' if self.vpp else '' + targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[ + 'targets'].keys() + hooked_count = self._hook_module(targets, model_chunk, vpp_stage) + print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.") + if not self.optimizer_hooked: - self.optimizer_hooked = True - print_rank_0("> parameter names:") - for name, param in model.named_parameters(): - print_rank_0(f"\t{name}") - for target_module, _ in self.config['targets'].items(): - if name.startswith(target_module): # name : language_model.encoder.layers.0.mlp.weight, target_module:language_model.encoder.layers.0 - self.param_name_list.append(name) - self.param2name[param] = name self.hook_optimizer() return - def build_tbtag_tensor_map(self, module_name, tag, tensor): - metrics = {} - rank = dist.get_rank() if dist.is_initialized() else None - key = get_summary_writer_tag_name(module_name, tag, rank) - if tensor is not None: - metrics[key] = tensor - return metrics + def generate_wgrad_metrics(self): + if not self.wg_distribution: + return {}, {} + + unreduced = {} + if self.weight_hooked: + for metric_name in self.ops: + unreduced[metric_name] = get_metrics(metric_name, self.grad_context.acc, self.eps) + self.grad_context.acc_metric = [unreduced] + + grad_dict = {} + for param, name in self.param2name.items(): + if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group): + continue + if self.dp_group and param_is_data_parallel_duplicate(self.dp_group): + continue + grad = param.main_grad if self.params_have_main_grad else param.grad + if grad is None: + print_warn_log(f"grad is None: {name}, maybe something wrong happened.") + continue + key = get_summary_writer_tag_name(name, 'post_grad', self.rank) + grad_dict[key] = grad + + reduced = {op: get_metrics(op, grad_dict, self.eps) for op in self.ops} + self.grad_context.post = [reduced] + + return reduced, unreduced + + def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None): + print_info_log(f'grad acc steps {grad_acc_steps}') + self.hook_optimizer(optimizer) + self.micro_batch_number = grad_acc_steps + self.backward_only = True + + self.dp_group = dp_group + self.tp_group = tp_group + + self._register_param_name(model) + self._hook_weights() + self.hook_modules(model, grad_acc_steps) def generate_param_metrics(self, tag, param_tensor): metrics = {} rank = dist.get_rank() if dist.is_initialized() else None - for param, name in self.param2name.items(): + for _, name in self.param2name.items(): key = get_summary_writer_tag_name(name, tag, rank) if name not in param_tensor or param_tensor[name] is None: continue metrics[key] = param_tensor[name] return metrics - - def generate_cc_metrics(self, cc_name, cc_tensor): - metrics = defaultdict(dict) - rank = dist.get_rank() if dist.is_initialized() else None - for op, tag2tensor in cc_tensor.data.items(): - for tag, tensor in tag2tensor.items(): - key = get_summary_writer_tag_name(cc_name, tag, rank) - metrics[op].update({key: tensor}) - cc_tensor.reset() - return metrics def write_adhoc_check(self, step): TrainerMon.tensor_metrics.flush(self.summary_writer) @@ -233,37 +363,47 @@ class TrainerMon: if not self.xy_distribution: return for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + if len(fwd_context.actv) == 0: + continue if not len(fwd_context.actv) == self.micro_batch_number: - print_warn_log(f"fwd_context.actv not equal to micro_batch_number: {len(fwd_context.actv)}, {self.micro_batch_number}") - for metric_name in self.ops: - write_metrics_tensorboard(metric_name, self.summary_writer, fwd_context.actv, step) + print_warn_log( + f"fwd_context.actv not equal to micro_batch_number: {len(fwd_context.actv)}, {self.micro_batch_number}") + self.write_metrics(self.ops, self.summary_writer, fwd_context.actv, step, 'actv') fwd_context.actv.clear() - for _, bwd_context in self.module_bwd_hook_context_by_module.items(): - if not len(bwd_context.actvgrad) == self.micro_batch_number: - print_warn_log(f"bwd_context.actvgrad not equal to micro_batch_number: {len(bwd_context.actvgrad)}, {self.micro_batch_number}") - for metric_name in self.ops: - write_metrics_tensorboard(metric_name, self.summary_writer, bwd_context.actvgrad, step) - bwd_context.actvgrad.clear() + self.write_metrics(self.ops, self.summary_writer, [self.grad_context.actv], step, 'grad_actv') - def hook_optimizer(self): + def write_grad_tb(self, step): + if not self.wg_distribution: + return + + self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced') + self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced') + + def hook_optimizer(self, optimizer=None): # in DDP by default use params_have_main_grad def optimizer_pre_step_hook(optimizer, args, kwargs): context = self.optimizer_context[optimizer] - if self.print_struct and not all(value == {} for value in self.module_struct.values()) and not self.struct_printed: + if self.print_struct and not all( + value == {} for value in self.module_struct.values()) and not self.struct_printed: self._smallest_rank_print("> module struct:") self._smallest_rank_print(json.dumps(self.module_struct, indent=4)) if not self.cc_log_only: raise Exception("exit after first step when print model struct") if self.cc_log_only and context.step > 0: self._smallest_rank_print("> Used communication ops and corresponding stack") - self._smallest_rank_print(json.dumps({k:[i.split(';') for i in v] for k,v in self.cc_logged_stack.items()}, indent=4)) + self._smallest_rank_print( + json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}, indent=4)) raise Exception("exit after first step when print cc stack") - - context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, context.param_adam_ratio = self.mix_precision_optimizer_mon.fetch_mv(self, - optimizer, self.param2name) - + self.generate_wgrad_metrics() + + mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name) + context.param_exp_avg = mv_result.exp_avg + context.param_exp_avg_sq = mv_result.exp_avg_sq + context.param_adam_update = mv_result.update + context.param_adam_ratio = mv_result.ratio + for param, name in self.param2name.items(): if "params_effrank" in self.config and name in self.config["params_effrank"]: context.param_effective_rank[name] = eff_rank(param.detach()) @@ -271,9 +411,8 @@ class TrainerMon: if grad is None: print_warn_log(f"grad is None: {name}, maybe something wrong happened.") continue - if self.wg_distribution: - context.param_weight_grad[name] = grad - if self.mg_direction: + + if self.mg_direction: if context.step == 0: same_direction_ratio = torch.tensor(1.) else: @@ -281,15 +420,11 @@ class TrainerMon: context.param_mg_direction[name] = same_direction_ratio tbtag_tensor_map = {} - if self.wg_distribution: - tbtag_tensor_map.update(self.generate_param_metrics('weight_grad', context.param_weight_grad)) if self.mv_distribution: tbtag_tensor_map.update(self.generate_param_metrics('exp_avg', context.param_exp_avg)) tbtag_tensor_map.update(self.generate_param_metrics('exp_avg_sq', context.param_exp_avg_sq)) if self.mg_direction: tbtag_tensor_map.update(self.generate_param_metrics('mg_direction', context.param_mg_direction)) - # if not tbtag_tensor_map: - # return metric_dict = {} for metric_name in self.ops: metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps) @@ -298,6 +433,7 @@ class TrainerMon: cc_metrics = self.generate_cc_metrics(k, c) for op, m in cc_metrics.items(): metric_dict[op].update(m) + if not metric_dict: return context.metric_list.append(metric_dict) @@ -307,29 +443,57 @@ class TrainerMon: context = self.optimizer_context[optimizer] rank = dist.get_rank() if dist.is_initialized() else None + if self.anomaly_data_factory: + self.anomaly_data_factory.set_call_id(self.param_name_call_id) self.write_xy_tb(context.step) + self.write_grad_tb(context.step) self.write_adhoc_check(context.step) if self.ur_distribution: for param_name, _ in context.param_adam_update.items(): - self.update_heatmap_visualizer[param_name].visualize(get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer) + self.update_heatmap_visualizer[param_name].visualize( + get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer) for param_name, _ in context.param_adam_ratio.items(): - self.ratio_heatmap_visualizer[param_name].visualize(get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer) + self.ratio_heatmap_visualizer[param_name].visualize( + get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer) - for metric_name in self.ops: - if not context.metric_list: - break - write_metrics_tensorboard(metric_name, self.summary_writer, context.metric_list, context.step) + if context.metric_list: + self.write_metrics(self.ops, self.summary_writer, context.metric_list, context.step, 'other') context.metric_list.clear() context.step += 1 + self.grad_context.reset() + if self.anomaly_data_factory: + self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) + self.summary_writer.clear_anomalies() + self.call_id = 0 + self.param_name_call_id.clear() + return + + def patch_step(func, optimizer): + def wrapper(*args, **kwargs): + optimizer_pre_step_hook(optimizer, args, kwargs) + out = func(*args, **kwargs) + optimizer_post_step_hook(optimizer, args, kwargs) + return out + return wrapper + + if self.optimizer_hooked: return - if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list): - register_optimizer_step_pre_hook(optimizer_pre_step_hook) - register_optimizer_step_post_hook(optimizer_post_step_hook) + + if optimizer: + optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + + else: + if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list): + register_optimizer_step_pre_hook(optimizer_pre_step_hook) + register_optimizer_step_post_hook(optimizer_post_step_hook) + self.optimizer_hooked = True return def _smallest_rank_print(self, msg): + if not self.verbose: + return if dist.is_initialized(): if self.module_rank_list: if dist.get_rank() == min(self.module_rank_list): @@ -340,7 +504,34 @@ class TrainerMon: else: print_info_log(msg) - def _hook_module(self, target_names, module: torch.nn.Module, fwd_or_bkd): + def _register_param_name(self, model): + if self.param_registered: + return + if not isinstance(model, list): + model = [model] + + if len(model) > 1: + self.vpp = True + self._smallest_rank_print('vpp enabled') + + for vpp_stage, model_chunk in enumerate(model): + prefix = f'{Const.vpp}{vpp_stage}{Const.vpp_sep}' if self.vpp else '' + for param_name, param in model_chunk.named_parameters(): + name = prefix + squash_param_name(param_name) + for target in self.config['targets'].keys(): + if param_name.startswith(target) and param.requires_grad: + self._smallest_rank_print(f'>> monitoring: {name}') + setattr(param, "zero_out_wgrad", True) + if name in self.param2name.values() or name == '': + print_error_log(f'same name {name} for different param. Current param is {param_name}. \ + May be error of squash_param_name') + raise Exception("param with same name will be overwriten.") + self.param2name[param] = name + break + + self.param_registered = True + + def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''): if '_modules' not in module.__dict__: # nothing to hook return 0 @@ -351,15 +542,15 @@ class TrainerMon: self.module_struct[context.module_name].update( {"input": f"{get_param_struct(module_input)}", "output": f"{get_param_struct(module_output)}"}) return - if not self.xy_distribution: - return if not context.format_by_arg: context.set_format_by_arg('input', self.config['targets']) context.set_format_by_arg('output', self.config['targets']) if not context.verified: if not context.ignore_in: - context.focused_in_col = validate_config_spec(context.format_by_arg['input'], module_input, context.module_name, 'input') - context.focused_out_col = validate_config_spec(context.format_by_arg['output'], module_output, context.module_name, 'output') + context.focused_in_col = validate_config_spec(context.format_by_arg['input'], module_input, + context.module_name, 'input') + context.focused_out_col = validate_config_spec(context.format_by_arg['output'], module_output, + context.module_name, 'output') context.verified = True # expect output be tensor type tbtag_tensor_map = {} @@ -387,32 +578,40 @@ class TrainerMon: context: ModuleHookContext = self.module_bwd_hook_context_by_module[module] if self.print_struct: self.module_struct[context.module_name].update( - {"input_grad": f"{get_param_struct(input_grad)}", "output_grad": f"{get_param_struct(output_grad)}"}) - return - if not self.xy_distribution: + {"input_grad": f"{get_param_struct(input_grad)}", + "output_grad": f"{get_param_struct(output_grad)}"}) return if not context.format_by_arg: context.set_format_by_arg('input_grad', self.config['targets']) context.set_format_by_arg('output_grad', self.config['targets']) + if not context.format_by_arg: + return if not context.verified: if not context.ignore_in: - context.focused_in_col = validate_config_spec(context.format_by_arg['input_grad'], input_grad, context.module_name, 'input_grad') - context.focused_out_col = validate_config_spec(context.format_by_arg['output_grad'], output_grad, context.module_name, 'output_grad') + context.focused_in_col = validate_config_spec(context.format_by_arg['input_grad'], input_grad, + context.module_name, 'input_grad') + context.focused_out_col = validate_config_spec(context.format_by_arg['output_grad'], output_grad, + context.module_name, 'output_grad') context.verified = True tbtag_tensor_map = {} if not context.ignore_in: cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col] - tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'input_grad', cared_input_grad)) + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(context.module_name + f'_{context.micro_step}', f'input_grad', + cared_input_grad)) cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col] - tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'output_grad', cared_output_grad)) - metric_dict = {} - for metric_name in self.ops: - metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps) + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(context.module_name + f'_{context.micro_step}', f'output_grad', + cared_output_grad)) + if context.micro_step == 0 and context.actvgrad: - print_warn_log(f"actvgrad context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.") + print_warn_log( + f"actvgrad context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.") context.actvgrad.clear() - context.actvgrad.append(metric_dict) + + for metric_name in self.ops: + self.grad_context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps)) context.micro_step += 1 if context.micro_step == self.micro_batch_number: @@ -420,15 +619,45 @@ class TrainerMon: context.step += 1 return + if self.backward_only and self.forward_only: + print_warn_log('not enable backward_only and forward_only simultaneously') + hooked_count = 0 - for name, submodule in module.named_modules(): - self.module_struct[name] = {} - if name in target_names: - submodule.register_forward_hook(fwd_hook_fun) - self.module_fwd_hook_context_by_module[submodule] = ModuleHookContext(name) - if not self.forward_only: - submodule.register_full_backward_hook(bwd_hook_fun) - self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) - print_rank_0(f"> {name} is monitored successfully") - hooked_count += 1 + if self.xy_distribution or self.print_struct: + for module_name, submodule in module.named_modules(): + name = vpp_stage + module_name + self.module_struct[name] = {} + if name in target_names or module_name in target_names: + if not self.backward_only: + submodule.register_forward_hook(fwd_hook_fun) + self.module_fwd_hook_context_by_module[submodule] = ModuleHookContext(name) + if not self.forward_only: + submodule.register_full_backward_hook(bwd_hook_fun) + self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) + print_rank_0(f"> {name} is monitored successfully") + hooked_count += 1 return hooked_count + + def _hook_weights(self): + context = self.grad_context + + @torch.no_grad + def param_hook(*args, context_dict, param, key, name): + param.micro_step += 1 + self.param_name_call_id[name] = self.call_id + self.call_id += 1 + if param.micro_step == self.micro_batch_number: + param.micro_step = 0 + if self.params_have_main_grad: + context_dict[key] = param.main_grad.clone() + else: + context_dict[key] = param.grad.clone() + + for param, name in self.param2name.items(): + key = get_summary_writer_tag_name(name, 'acc_grad', self.rank) + setattr(param, 'micro_step', 0) + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + grad_acc.register_hook(partial(param_hook, context_dict=context.acc, param=param, key=key, name=name)) + + self.weight_hooked = True diff --git a/debug/accuracy_tools/kj600/kj600/module_metric.py b/debug/accuracy_tools/monitor/monitor/module_metric.py similarity index 65% rename from debug/accuracy_tools/kj600/kj600/module_metric.py rename to debug/accuracy_tools/monitor/monitor/module_metric.py index e09536b072cf7953e6b6106420936416d4264d0e..681e9d90829a3333302a85b7867ff5c91cc51dcf 100644 --- a/debug/accuracy_tools/kj600/kj600/module_metric.py +++ b/debug/accuracy_tools/monitor/monitor/module_metric.py @@ -1,15 +1,26 @@ import math +import re import statistics -from kj600.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm +from monitor.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm, get_mean def get_summary_writer_tag_name(module_or_param_name:str, tag:str, rank): if rank is None: return f"{module_or_param_name}/{tag}" else: - return f"{module_or_param_name}/{rank}/{tag}" - + return f"{module_or_param_name}/rank{rank}/{tag}" + +def squash_param_name(param_name): + name = '' + for pattern in ['(?<=layers\.)[\d]*.*', 'embeddings?\.(.*)', 'final.*', 'output.*','norm.*']: + match = re.findall(pattern, param_name) + if match: + name += match[0] + break + if name == '': + name = param_name + return name # 用于存储所有metric实现类的注册表 config_metric_registry = {} @@ -28,9 +39,9 @@ class TensorMetrics: self.metrics = {} #tensor_tag --> [] self.cur_idx = {} - fun_map = {"norm": get_norm, "max": get_max, "min": get_min} + fun_map = {"norm": get_norm, "max": get_max, "min": get_min, "mean": get_mean} #get stats and insert into metrics dictionary - def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank, eps=1e-8): + def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank): prefix = get_summary_writer_tag_name(module_name, tensor_name, rank) for stat_op in stat_ops: y = TensorMetrics.fun_map[stat_op](tensor) @@ -75,6 +86,19 @@ class MinMetric(Metric): summary_writer.add_scalar(f'{key}_min', min_value, step) +@register_config_metric("mean") +class MeanMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + return get_mean(tensor) + + @staticmethod + def metric_tensorboard(metric_name, summary_writer, metric_value, step): + for key in metric_value[0][metric_name].keys(): + mean_value = sum([item[metric_name][key].item() for item in metric_value]) / len(metric_value) + summary_writer.add_scalar(f'{key}_mean', mean_value, step) + + @register_config_metric("max") class MaxMetric(Metric): @staticmethod @@ -129,17 +153,16 @@ class NaNsMetric(Metric): class IdentMetric(Metric): @staticmethod def get_metric_value(tensor, eps): - if tensor.dim() != 0: - return None - return tensor + if tensor.dim() == 0: + return tensor @staticmethod - def metric_tensorboard(metric_name, summary_writer, metric_value, step): #metric_value is a dict, key is parameter name and value is a list of scalar tensor + def metric_tensorboard(metric_name, summary_writer, metric_value, context): #metric_value is a dict, key is parameter name and value is a list of scalar tensor if len(metric_value) == 1: for key, value in metric_value[0][metric_name].items(): if not value: continue - summary_writer.add_scalar(f'{key}_identical', value.item(), step) + summary_writer.add_scalar(f'{key}_identical', value.item(), context) def get_metrics(metric_name, tag2tensor, eps): @@ -150,9 +173,32 @@ def get_metrics(metric_name, tag2tensor, eps): raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e -def write_metrics_tensorboard(metric_name, summary_writer, metric_value, step): - try: - fun_metric = config_metric_registry[metric_name] - return fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step) - except KeyError as e: - raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e +def write_metrics_tensorboard(ops, summary_writer, metric_value, step, prefix=''): + for metric_name in ops: + try: + fun_metric = config_metric_registry[metric_name] + fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step) + except KeyError as e: + raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e + +def write_metrics_csv(ops, summary_writer, metric_value, step, prefix=''): + for metric_name in ops: + try: + fun_metric = config_metric_registry[metric_name] + fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step) + + except KeyError as e: + raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e + + if not summary_writer.header: + if prefix in ['actv', 'grad_actv']: + summary_writer.header = ['param_name'] + ['input_'+op for op in ops] + ['output_'+op for op in ops] + else: + summary_writer.header = ['param_name'] + ops + + for key in metric_value[0][ops[0]].keys(): + if 'vpp' in key: + summary_writer.header.insert(0, 'vpp_stage') + break + summary_writer.write_csv(prefix, step) + summary_writer.header = [] diff --git a/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py b/debug/accuracy_tools/monitor/monitor/module_spec_verifier.py similarity index 91% rename from debug/accuracy_tools/kj600/kj600/module_spec_verifier.py rename to debug/accuracy_tools/monitor/monitor/module_spec_verifier.py index 395aa82f17a87cdf742a8294e29ccb1c32081200..062d5c230ef9d438ab27b3e704b12f3828e7f076 100644 --- a/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py +++ b/debug/accuracy_tools/monitor/monitor/module_spec_verifier.py @@ -2,15 +2,8 @@ import json import re import abc import torch -from kj600.utils import check_file_valid_readable -def get_config(file_path='config.json'): - check_file_valid_readable(file_path) - with open(file_path, 'r') as file: - config = json.load(file) - return config - # 用于存储所有validator实现类的注册表 config_validator_registry = {} @@ -40,7 +33,6 @@ class TensorValidator(ConfigValidator): def validate(self, actual_data, module_name:str, data_type:str, pattern_match): if not torch.is_tensor(actual_data): raise ValueError(f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.") - return None @register_config_validator diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/monitor/monitor/optimizer_collect.py similarity index 51% rename from debug/accuracy_tools/kj600/kj600/optimizer_collect.py rename to debug/accuracy_tools/monitor/monitor/optimizer_collect.py index 285f17ca6dc6a00814b0847c7d203524d8a8caa6..7d2f488efad18d23ca48063f9e21193f24c9e05e 100644 --- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py +++ b/debug/accuracy_tools/monitor/monitor/optimizer_collect.py @@ -1,10 +1,12 @@ -from collections import defaultdict +from abc import ABC, abstractmethod +from collections import defaultdict, namedtuple import torch import torch.distributed as dist -from kj600.visualizer import HeatmapVisualizer +from monitor.utils import print_warn_log -def print_rank_0(message, debug=False, force=False): + +def print_rank_0(message): if dist.is_initialized(): if dist.get_rank() == 0: print(message) @@ -12,20 +14,29 @@ def print_rank_0(message, debug=False, force=False): print(message) -class MixPrecsionOptimizerMon: +MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) + + +class OptimizerMon(ABC): wrapped_optimizer = None + @classmethod + def set_wrapped_optimizer(cls, wrapped_optimizer): + cls.wrapped_optimizer = wrapped_optimizer + + @abstractmethod + def fetch_mv(self, monitor, torch_opt, params2name): + pass + + +class MixPrecisionOptimizerMon(OptimizerMon): def __init__(self) -> None: self.fp16_to_fp32_param = {} - @staticmethod - def set_wrapped_optimizer(_wrapped_optimizer): - MixPrecsionOptimizerMon.wrapped_optimizer = _wrapped_optimizer - # parameter tensors we want to monitor and their names are in params2name_dict # base_optimizer is pytorch optimizer, wrapped_optimizer is a normal object with base_optimizer def fetch_mv(self, monitor, torch_opt, params2name): - mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer + mix_prec_opt = self.wrapped_optimizer if not self.fp16_to_fp32_param and mix_prec_opt is not None: for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups): @@ -44,8 +55,12 @@ class MixPrecsionOptimizerMon: param = self.fp16_to_fp32_param[param] if param in torch_opt.state: - exp_avg = torch_opt.state[param]["exp_avg"] - exp_avg_sq = torch_opt.state[param]["exp_avg_sq"] + state_param = torch_opt.state.get(param, None) + exp_avg = state_param.get("exp_avg", None) + exp_avg_sq = state_param.get("exp_avg_sq", None) + if exp_avg is None or exp_avg_sq is None: + print_warn_log(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.") + continue if monitor.mv_distribution: exp_avg_dict[name] = exp_avg exp_avg_sq_dict[name] = exp_avg_sq @@ -53,15 +68,15 @@ class MixPrecsionOptimizerMon: exp_avg_dict[name] = exp_avg if monitor.ur_distribution: update_dict[name] = exp_avg / (torch.sqrt(exp_avg_sq) + torch_opt.defaults['eps']) - ratio_dict[name] = exp_avg / torch.sqrt(exp_avg_sq) + ratio_dict[name] = (exp_avg / torch.sqrt(exp_avg_sq)).nan_to_num(0) monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name]) monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) - return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict + return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict) -class MegatronDistributedOptimizerMon(MixPrecsionOptimizerMon): +class MegatronDistributedOptimizerMon(MixPrecisionOptimizerMon): def fetch_mv(self, monitor, torch_opt, params2name): - mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer + mix_prec_opt = self.wrapped_optimizer if not (hasattr(mix_prec_opt, "model_float16_groups") and hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")): raise Exception("megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, \ if not, please check megatron-lm version") @@ -73,18 +88,48 @@ class MegatronDistributedOptimizerMon(MixPrecsionOptimizerMon): return self._fetch_mv_in_adam(params2name, torch_opt, monitor) -class DummyOptimizerMon(MixPrecsionOptimizerMon): +class MegatronFP32OptimizerMon(OptimizerMon): + def fetch_mv(self, monitor, torch_opt, params2name): + exp_avg_dict = defaultdict(float) + exp_avg_sq_dict = defaultdict(float) + update_dict = defaultdict() + ratio_dict = defaultdict() + + for param, name in params2name.items(): + if param in torch_opt.state: + state_param = torch_opt.state.get(param, None) + exp_avg = state_param.get("exp_avg", None) + exp_avg_sq = state_param.get("exp_avg_sq", None) + if exp_avg is None or exp_avg_sq is None: + print_warn_log(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.") + continue + if monitor.mv_distribution: + exp_avg_dict[name] = exp_avg + exp_avg_sq_dict[name] = exp_avg_sq + if monitor.mg_direction: + exp_avg_dict[name] = exp_avg + if monitor.ur_distribution: + update_dict[name] = exp_avg / (torch.sqrt(exp_avg_sq) + torch_opt.defaults['eps']) + ratio_dict[name] = (exp_avg / torch.sqrt(exp_avg_sq)).nan_to_num(0) + monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name]) + monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) + return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict) + + +class DummyOptimizerMon(OptimizerMon): def fetch_mv(self, monitor, torch_opt, params2name): - return None, None, None, None + return MVResult(exp_avg=None, exp_avg_sq=None, update=None, ratio=None) class OptimizerMonFactory: @staticmethod - def create_optimizer_mon(opt_ty:str): + def create_optimizer_mon(opt_ty: str): if opt_ty == "Megatron_Float16OptimizerWithFloat16Params": - return MixPrecsionOptimizerMon() + return MixPrecisionOptimizerMon() if opt_ty == "Megatron_DistributedOptimizer": return MegatronDistributedOptimizerMon() + if opt_ty == "Megatron_FP32Optimizer": + return MegatronFP32OptimizerMon() if opt_ty is None or opt_ty == "unknown": return DummyOptimizerMon() raise Exception("opt_ty should be Megatron_Float16OptimizerWithFloat16Params or Megatron_DistributedOptimizer or None or unknown") diff --git a/debug/accuracy_tools/monitor/monitor/unittest/test_monitor.py b/debug/accuracy_tools/monitor/monitor/unittest/test_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..a7017a6be5e01ad1b332c08e2332a47eee1de023 --- /dev/null +++ b/debug/accuracy_tools/monitor/monitor/unittest/test_monitor.py @@ -0,0 +1,145 @@ +import sys +import os +import re +import argparse +import pandas as pd +from glob import glob +from collections import defaultdict + + +def parse_logfile(logfile): + grad_norm = [] + step = [] + with open(logfile) as f: + for line in f.readlines(): + if 'consumed samples' in line: + grad_norm.append(float(re.findall('(?<=grad norm\: )[\d\.]*', line)[0])) + # step = int(re.findall('(?<=iteration)[ \d]*', line)[0]) + return grad_norm + + +def parse_monitor_output(output_dir): + reduced = {} + unreduced = {} + for dir in glob(output_dir+'*'): + rank = int(re.findall('(?<=rank)[\d]*', dir)[0]) + unreduced[rank] = [] + reduced[rank] = [] + for file in os.listdir(dir): + # step = int(re.search("(?<=reduced\_)[\d]*", file)[0]) + # if step != 0: + # continue + df = pd.read_csv(os.path.join(dir, file)) + if '_unreduced_' in file: + unreduced[rank].append(df) + pass + elif '_reduced_' in file: + reduced[rank].append(df) + else: + print(f'unexpected file {file} in {dir}') + return reduced, unreduced + +def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel): + steps = len(reduced[0]) + world_size = len(reduced) + errors = [] + for index, row in unreduced[0][0].iterrows(): + param = row['param_name'] + is_tp_duplicate = False + for step in range(2): + # sum reduced + reduced_mean = 0. + for rank in range(world_size): + if len(reduced[rank]) == 0: + continue + df = reduced[rank][step] + value = list(df[df['param_name'] == param]['mean']) + if value == []: + if step == 0: + is_tp_duplicate = True + continue + reduced_mean += value[0] + + # sum unreduced + unreduced_mean = 0. + for rank in range(world_size): + df = unreduced[rank][step] + value = list(df[df['param_name'] == param]['mean']) + if value == []: + continue + unreduced_mean += list(df[df['param_name'] == param]['mean'])[0] + + unreduced_mean /= dp_size + if is_tp_duplicate and (not sequence_parallel or 'embedding' in param): + unreduced_mean /= tp_size + try: + assert_equal(unreduced_mean, reduced_mean) + except AssertionError as e: + errors.append([param, step, e, is_tp_duplicate]) + if errors: + print(errors) + else: + print(f'grad mean is in consist between unreduced grad and reduced grad monitord.') + + + +def assert_equal(a, b): + if b == 0 or a == 0: + return + if b == 0: + rel_diff = a + elif a == 0: + rel_diff = b + else: + rel_diff = abs(a/b-1) + assert rel_diff<0.01, f'{a}, {b}, {rel_diff}' + + +def valid_total_norm(total_norm, reduced, duplicate_embedding): + steps = len(total_norm) + world_size = len(reduced) + errors = [] + for step in range(steps): + calculated_norm = 0. + for rank in range(world_size): + if len(reduced[rank]) == 0: + if step == 0: + print(f'rank {rank} is duplicated in dp group') + continue + for index, row in reduced[rank][step].iterrows(): + if duplicate_embedding and 'word_embedding' in row['param_name']: + continue + calculated_norm += row['norm']**2 + try: + assert_equal(calculated_norm**0.5, total_norm[step]) + except AssertionError as e: + errors.append([step, e]) + if errors: + print('total norm errors: ', errors) + else: + print('grad norm in consist between training log and reduced gradients monitored') + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--monitor_output', '-m', type=str, required=True, help='path prefix to the output of monitor e.g. monitor_output/Aug12_07-16') + parser.add_argument('--logfile', '-l', type=str, required=True, help='path to the training log file') + parser.add_argument('--tp_size', '-t', type=int, required=True, help='tp parallel size') + parser.add_argument('--dp_size', '-d', type=int, required=True, help='dp parallel size') + parser.add_argument('--pp_size', '-p', type=int, required=True, help='pp parallel size') + parser.add_argument('--untie_embeddings_and_output_weights', '-u', action="store_true", default=False, help='whether untie_embeddings_and_output_weights in pp parallel') + parser.add_argument('--sequence_parallel', '-s', action="store_true", default=False, help='whether sequence parallel is enabled. Add -s to store true') + + args = parser.parse_args() + + assert args.tp_size > 0, 'if tp not enabled, set tp_size = 1' + assert args.dp_size > 0, 'if tp not enabled, set dp_size = 1' + assert args.pp_size > 0, 'if tp not enabled, set pp_size = 1' + + total_norm = parse_logfile(args.logfile) + reduced, unreduced = parse_monitor_output(args.monitor_output) + + duplicate_embedding = not args.untie_embeddings_and_output_weights and args.pp_size > 1 + + valid_total_norm(total_norm, reduced, duplicate_embedding) + valid_reduce(reduced, unreduced, args.tp_size, args.dp_size, args.sequence_parallel) \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/utils.py b/debug/accuracy_tools/monitor/monitor/utils.py similarity index 73% rename from debug/accuracy_tools/kj600/kj600/utils.py rename to debug/accuracy_tools/monitor/monitor/utils.py index 53d47d9988647202bdb711afde38b94b51899b5a..3aed6911c440da507a0b3ec8a02091308061e84b 100644 --- a/debug/accuracy_tools/kj600/kj600/utils.py +++ b/debug/accuracy_tools/monitor/monitor/utils.py @@ -107,4 +107,29 @@ def check_file_valid_readable(path): def check_file_valid_writable(path): check_file_valid(path) check_path_writability(path) - \ No newline at end of file + + +def make_file_safety(file_path: str, permission=0o640): + if os.path.islink(file_path): + raise RuntimeError(f"Invalid soft link path: {file_path}") + file_real_path = os.path.realpath(file_path) + if os.path.exists(file_real_path): + return + parent_path = os.path.dirname(file_real_path) + if not os.path.exists(parent_path): + os.makedirs(parent_path, mode=0o750, exist_ok=True) + if not os.access(parent_path, os.W_OK): + raise PermissionError(f"The path {parent_path} is not writable!") + try: + os.close(os.open(file_real_path, os.O_WRONLY | os.O_CREAT, permission)) + except OSError as e: + raise RuntimeError("Can't create file: " + file_real_path) from e + os.chmod(file_real_path, permission) + + +def create_directory(dir_path): + dir_path = os.path.realpath(dir_path) + try: + os.makedirs(dir_path, mode=0o750, exist_ok=True) + except OSError as ex: + raise RuntimeError("Failed to create directory. Please check the path permission or disk space.") from ex \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/visualizer.py b/debug/accuracy_tools/monitor/monitor/visualizer.py similarity index 97% rename from debug/accuracy_tools/kj600/kj600/visualizer.py rename to debug/accuracy_tools/monitor/monitor/visualizer.py index e1929bfa3fb338b1cb66cda80a128e83176bfcbf..151f1ea1c451a5f27d250df591c4d00f64a1a34c 100644 --- a/debug/accuracy_tools/kj600/kj600/visualizer.py +++ b/debug/accuracy_tools/monitor/monitor/visualizer.py @@ -1,7 +1,7 @@ import torch import numpy as np import matplotlib.pyplot as plt -from kj600.features import cal_histc +from monitor.features import cal_histc class HeatmapVisualizer: diff --git a/debug/accuracy_tools/kj600/pyproject.toml b/debug/accuracy_tools/monitor/pyproject.toml similarity index 78% rename from debug/accuracy_tools/kj600/pyproject.toml rename to debug/accuracy_tools/monitor/pyproject.toml index 5df968563345dd07ed477ec73b967b63c6e812a6..5111fbbbd53a9e88bd52f12a61621d2dc3ed6203 100644 --- a/debug/accuracy_tools/kj600/pyproject.toml +++ b/debug/accuracy_tools/monitor/pyproject.toml @@ -3,11 +3,9 @@ requires = ["setuptools"] build-backend = "setuptools.build_meta" [project] -name = "kj600" +name = "monitor" version = "0.0.1" dependencies = [ - "torch", - "torch_npu", "torchvision", "tensorboard", "matplotlib", @@ -16,4 +14,7 @@ dependencies = [ ] [tool.setuptools.packages] -find = {} # Scan the project directory with the default parameters \ No newline at end of file +find = {} # Scan the project directory with the default parameters + +[tool.setuptools.package-data] +monitor = ["distributed/*.yaml"] \ No newline at end of file diff --git "a/debug/accuracy_tools/kj600/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md" "b/debug/accuracy_tools/monitor/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md" similarity index 100% rename from "debug/accuracy_tools/kj600/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md" rename to "debug/accuracy_tools/monitor/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md" diff --git a/debug/accuracy_tools/atat/README.md b/debug/accuracy_tools/msprobe/README.md similarity index 62% rename from debug/accuracy_tools/atat/README.md rename to debug/accuracy_tools/msprobe/README.md index de7d74f8f29cef077bec8b665afcc537974df774..663dde17e1af734b7f3c499251a5cec91f848a36 100644 --- a/debug/accuracy_tools/atat/README.md +++ b/debug/accuracy_tools/msprobe/README.md @@ -1,11 +1,26 @@ # MindStudio精度调试工具 -MindStudio精度调试工具(ascend_training_accuracy_tools),简称atat,是MindStudio Training Tools工具链下精度调试部分的工具包。主要包括精度预检和精度比对等子工具,当前适配场景包括PyTorch和MindSpore。 +MindStudio精度调试工具(MindStudio Probe),简称msprobe,是MindStudio Training Tools工具链下精度调试部分的工具包。主要包括精度预检和精度比对等子工具,当前适配场景包括PyTorch和MindSpore。 ## 工具安装 -精度工具合一软件包名称:`ascend_training_accuracy_tools-{version}-py3-none-any.whl` +精度工具合一软件包名称:`mindstudio_probe-{version}-py3-none-any.whl` +### pip安装 + ```shell + pip install mindstudio-probe + ``` +使用`pip install mindstudio-probe==版本号`可安装指定版本的包。 + +pip命令会自动安装最新的包及其配套依赖。 + +提示如下信息则表示安装成功。 + +```bash +Successfully installed mindstudio_probe-{version} +``` + +### 下载whl包安装 1. 使用pip命令安装numpy、openpyxl、pandas、PyYAML、rich、torch、tqdm依赖。 若环境中已安装部分依赖,不需要重复安装。 @@ -16,6 +31,7 @@ MindStudio精度调试工具(ascend_training_accuracy_tools),简称atat, | 版本 | 发布日期 | 支持PyTorch版本 | 下载链接 | 校验码 | | ----- | ---------- | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | + | 1.0.1 | 2024-07-25 | 2.0/2.1/2.2 | [mindstudio_probe-1.0.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.1-py3-none-any.whl) | b699e224e4d4e3bcf9412c54fa858a1ee370f0d7a2bc69cb3f1273ac14a6dc82 | | 1.0 | 2024-07-09 | 2.0/2.1/2.2 | [ascend_training_accuracy_tools-1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/att/1.0/ascend_training_accuracy_tools-1.0-py3-none-any.whl) | 5016dfe886c5d340ec6f60a959673355855f313c91f100680da814efb49f8e81 | | 0.0.3 | 2024-06-11 | 2.0/2.1/2.2 | [ascend_training_accuracy_tools-0.0.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/att/0.0/ascend_training_accuracy_tools-0.0.3-py3-none-any.whl) | f46d9714704859e2d67861a65bbb3c76b0a250cf6e238b978b5b959ab1fe125a | | 0.0.2 | 2024-05-23 | 2.0/2.1/2.2 | [ascend_training_accuracy_tools-0.0.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/att/0.0/ascend_training_accuracy_tools-0.0.2-py3-none-any.whl) | 2e35809bde559e9c4d2f16a02ccde779ed9e436bb65fded0b7ebaf6ac2c88d93 | @@ -43,25 +59,79 @@ MindStudio精度调试工具(ascend_training_accuracy_tools),简称atat, 4. 执行如下命令进行安装。 ```bash - pip3 install ./ascend_training_accuracy_tools-{version}-py3-none-any.whl + pip3 install ./mindstudio_probe-{version}-py3-none-any.whl ``` 若为覆盖安装,请在命令行末尾增加“--force-reinstall”参数强制安装,例如: ```bash - pip3 install ./ascend_training_accuracy_tools-{version}-py3-none-any.whl --force-reinstall + pip3 install ./mindstudio_probe-{version}-py3-none-any.whl --force-reinstall ``` 提示如下信息则表示安装成功。 ```bash - Successfully installed ascend_training_accuracy_tools-{version} + Successfully installed mindstudio_probe-{version} + ``` + +### 从源码安装 +1. 克隆或者下载项目源代码 + + ```shell + git clone https://gitee.com/ascend/mstt.git + cd debug/accuracy_tools + ``` + +2. 安装setuptools和wheel + + ```shell + pip install setuptools wheel + ``` + +3. 安装msprobe + + ```shell + python setup.py install + ``` + 提示出现如下信息则表示源码安装成功。 + ```shell + Finished processing dependencies for mindstudio-probe=={version} ``` +### 查看msprobe工具信息 + +执行如下命令查看msprobe工具信息。 + +```bash +pip show mindstudio-probe +``` + +输出结果如下示例: + +```bash +Name: mindstudio-probe +Version: 1.0 +Summary: This is a pytorch precision comparison tools +Home-page: +Author: +Author-email: +License: +Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages +Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel +Required-by: +``` + +关键字段含义: + +- Name:工具名称。 +- Version:工具版本号。 +- Summary:工具概述。 +- Location:工具安装路径。 +- Requires:工具依赖。 ## 工具使用 -安装atat工具后,可以按照如下思路选择合适的子工具进行精度调试: +安装msprobe工具后,可以按照如下思路选择合适的子工具进行精度调试: 1. 判断框架场景。 @@ -107,32 +177,36 @@ MindStudio精度调试工具(ascend_training_accuracy_tools),简称atat, MindSpore场景:暂不支持。 -上述流程中的工具均为atat工具的子工具,使用相同的命令行,格式如下: +7. TorchAir场景的数据采集和精度比对。 + + 仅支持PyTorch场景,详见[TorchAir训练场景-整网算子精度比对](./pytorch/doc/torchair_compare.md)和[TorchAir训练场景Dump案例](./pytorch/doc/torchair_dump_sample.md)。 + +上述流程中的工具均为msprobe工具的子工具,使用相同的命令行,格式如下: 精度预检工具 ```bash -atat -f run_ut [-h] +msprobe -f run_ut [-h] ``` ```bash -atat -f multi_run_ut [-h] +msprobe -f multi_run_ut [-h] ``` ```bash -atat -f api_precision_compare [-h] +msprobe -f api_precision_compare [-h] ``` 溢出解析工具 ```bash -atat -f run_overflow_check [-h] +msprobe -f run_overflow_check [-h] ``` 数据解析工具 ```bash -atat -f parse [-h] +msprobe -f parse [-h] ``` | 参数 | 说明 | diff --git a/debug/accuracy_tools/atat/mindspore/dump/__init__.py b/debug/accuracy_tools/msprobe/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/mindspore/dump/__init__.py rename to debug/accuracy_tools/msprobe/__init__.py diff --git a/debug/accuracy_tools/atat/config/README.md b/debug/accuracy_tools/msprobe/config/README.md similarity index 89% rename from debug/accuracy_tools/atat/config/README.md rename to debug/accuracy_tools/msprobe/config/README.md index a998704993accefa6f167c6fac6ec18218cca211..7d11a3652539d457736e1c8d5be62cabb558422d 100644 --- a/debug/accuracy_tools/atat/config/README.md +++ b/debug/accuracy_tools/msprobe/config/README.md @@ -2,13 +2,38 @@ 当前配置文件主要为PrecisionDebugger接口执行dump或无标杆比对操作时调用的配置,当PrecisionDebugger接口未指定该配置文件时,使用该文件的默认配置。配置文件详见[config.json](./config.json)。 +当在环境上安装msprobe工具后,config.json文件位置可通过如下方式查找: + +查找msprobe工具安装路径。 + +``` +pip show mindstudio-probe +``` + +输出结果如下示例: + +``` +Name: mindstudio-probe +Version: 1.0 +Summary: This is a pytorch precision comparison tools +Home-page: +Author: +Author-email: +License: +Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages +Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel +Required-by: +``` + +Location字段为msprobe工具的安装路径,那么config.json文件位置为/home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/config + ## 参数说明 ### **通用配置参数** | 参数名 | 说明 | 是否必选 | | ----------------- | ------------------------------------------------------------ | -------- | -| task | dump的任务类型,str类型。可取值"free_benchmark"(无标杆比对,仅PyTorch场景支持)、"statistics"(仅dump API统计信息,默认值)、"tensor"(dump API统计信息和完全复刻整网的API运行情况的真实数据)、"overflow_check"(溢出检测)。配置示例:"task": "tensor"。根据task参数取值的不同,可以配置不同场景参数,详见:“**task配置为free_benchmark**”,“**task配置为statistics**”,“**task配置为tensor**”,“**task配置为overflow_check**”。 | 否 | +| task | dump的任务类型,str类型。可取值:
"free_benchmark"(无标杆比对,仅PyTorch场景支持)。
"statistics"(仅dump API统计信息,默认值)。
"tensor"(dump API统计信息和完全复刻整网的API运行情况的真实数据)。
"overflow_check"(溢出检测,仅PyTorch和MindSpore静态图场景支持)。
"run_ut"(精度预检配置,仅PyTorch场景支持)。
配置示例:"task": "tensor"。
根据task参数取值的不同,可以配置不同场景参数,详见:“**task配置为free_benchmark**”,“**task配置为statistics**”,“**task配置为tensor**”,“**task配置为overflow_check**”,“**task配置为run_ut**”。 | 否 | | dump_path | 设置dump数据目录路径,str类型。配置示例:"dump_path": "./dump_path"。MindSpore场景仅支持绝对路径。 | 是 | | rank | 指定对某张卡上的数据进行dump,list[int]类型,默认未配置(表示dump所有卡的数据),应配置为大于等于0的整数,且须配置实际可用的Rank ID。配置示例:"rank": [1]。
对于PyTorch场景,Rank ID从0开始计数,最大取值为所有节点可用卡总数-1,若所配置的值大于实际训练所运行的卡的Rank ID,则dump数据为空,比如当前环境Rank ID为0到7,实际训练运行0到3卡,此时若配置Rank ID为4或不存在的10等其他值,此时dump数据为空。
对于MindSpore场景,所有节点的Rank ID均从0开始计数,最大取值为每个节点可用卡总数-1,config.json配置一次rank参数对所有节点同时生效。 | 否 | | step | 指定dump某个step的数据,list[int]类型。默认未配置,表示dump所有step数据。dump特定step时,须指定为训练脚本中存在的step。step为list格式,可配置逐个step,例如:"step": [0,1,2]。 | 否 | @@ -85,6 +110,18 @@ task配置为free_benchmark时,开启**无标杆比对**,在NPU环境下通 | overflow_nums | 控制溢出次数,int类型,仅PyTorch场景支持,表示第N次溢出时,停止训练,过程中检测到溢出API对应kernel数据均dump。配置示例:"overflow_nums": 3。默认为1,即检测到1次溢出,训练停止,配置为-1时,表示持续检测溢出直到训练结束。 | 否 | | check_mode | MindSpore场景kernel级别的溢出检测,str类型,可取值"aicore"(开启AI Core的溢出检测)、"atomic"(开启Atomic的溢出检测)、"all"(开启AI Core和Atomic的溢出检测,默认值)。配置示例"check_mode": "aicore"。 | 否 | +### task配置为run_ut + +仅PyTorch场景支持。 + +| 参数名称 | 说明 | 是否必选 | +| --------------- | ------------------------------------------------------------ | -------- | +| white_list | API dump白名单,仅对指定的API进行dump。配置示例:"white_list": ["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 | +| black_list | API dump黑名单,被指定的API不进行dump。配置示例:"black_list": ["conv1d", "conv2d"]。默认未配置黑名单,即dump全量API数据。 | 否 | +| error_data_path | 配置保存精度未达标的API输入输出数据路径,默认为当前路径。配置示例"error_data_path": "./"。 | 否 | + +说明:white_list和black_list同时配置时,二者配置的API名单若无交集,则白名单生效,若API名单存在交集,则白名单排除的部分以及交集的API不进行dump。 + ## 配置示例 以下示例包含当前支持的所有场景可配置的完整参数。 @@ -180,6 +217,27 @@ task配置为free_benchmark时,开启**无标杆比对**,在NPU环境下通 } ``` +### PyTorch场景task配置为run_ut + +```json +{ + "task": "run_ut", + "dump_path": "/home/data_dump", + "rank": [], + "step": [], + "level": "L1", + "seed": 1234, + "is_deterministic": false, + "enable_dataloader": false, + + "run_ut": { + "white_list": [], + "black_list": [], + "error_data_path": "./" + } +} +``` + ### MindSpore场景task配置为statistics ```json @@ -394,4 +452,4 @@ train_loader = torch.utils.data.DataLoader( 关闭dropout: -在使用from atat.pytorch import PrecisionDebugger后,工具会自动将torch.nn.functional.dropout、torch.nn.functional.dropout2d、torch.nn.functional.dropout3d、torch.nn.Dropout、torch.nn.Dropout2d、torch.nn.Dropout3d的接口参数p置为0。 +在使用from msprobe.pytorch import PrecisionDebugger后,工具会自动将torch.nn.functional.dropout、torch.nn.functional.dropout2d、torch.nn.functional.dropout3d、torch.nn.Dropout、torch.nn.Dropout2d、torch.nn.Dropout3d的接口参数p置为0。 diff --git a/debug/accuracy_tools/atat/config/config.json b/debug/accuracy_tools/msprobe/config/config.json similarity index 80% rename from debug/accuracy_tools/atat/config/config.json rename to debug/accuracy_tools/msprobe/config/config.json index 70a630a40af1fbd827d060d4783071415d2cb610..f3191ad985c95121ca314146ef96e14dafe7eb09 100644 --- a/debug/accuracy_tools/atat/config/config.json +++ b/debug/accuracy_tools/msprobe/config/config.json @@ -7,6 +7,7 @@ "seed": 1234, "is_deterministic": false, "enable_dataloader": false, + "enable_step_auto_dump": false, "acl_config": "", "tensor": { "scope": [], @@ -24,5 +25,10 @@ "overflow_check": { "overflow_nums": 1, "check_mode":"all" + }, + "run_ut": { + "white_list": [], + "black_list": [], + "error_data_path": "./" } } \ No newline at end of file diff --git a/debug/accuracy_tools/atat/config/img/free_benchmark.png b/debug/accuracy_tools/msprobe/config/img/free_benchmark.png similarity index 100% rename from debug/accuracy_tools/atat/config/img/free_benchmark.png rename to debug/accuracy_tools/msprobe/config/img/free_benchmark.png diff --git a/debug/accuracy_tools/atat/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py similarity index 70% rename from debug/accuracy_tools/atat/core/common/const.py rename to debug/accuracy_tools/msprobe/core/common/const.py index dea829c3ffad125d8c2ac98d97215db37e895b7e..84c5ac1abe4fe44402caf3a1877d99f582626111 100644 --- a/debug/accuracy_tools/atat/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -1,7 +1,9 @@ import os import stat + import numpy as np + class Const: """ Class for const @@ -15,6 +17,17 @@ class Const: OFF = 'OFF' BACKWARD = 'backward' FORWARD = 'forward' + PRIMITIVE_PREFIX = 'Primitive' + DEFAULT_LIST = [] + DEFAULT_PATH = './' + WHITE_LIST = 'white_list' + BLACK_LIST = 'black_list' + DUMP_TENSOR_DATA = 'dump_tensor_data' + NONE = None + THREE_SEGMENT = 3 + FOUR_SEGMENT = 4 + SIX_SEGMENT = 6 + SEVEN_SEGMENT = 7 # dump mode ALL = "all" @@ -25,6 +38,8 @@ class Const: API_LIST = "api_list" API_STACK = "api_stack" DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK] + AUTO = "auto" + ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF] SUMMARY = "summary" MD5 = "md5" SUMMARY_MODE = [ALL, SUMMARY, MD5] @@ -35,8 +50,10 @@ class Const: PKL_SUFFIX = ".pkl" NUMPY_SUFFIX = ".npy" + PT_SUFFIX = ".pt" ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024 TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024 + ONE_MB = 1048576 # 1 * 1024 * 1024 FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' DISTRIBUTED_PREFIX_LENGTH = 60 # env dump path @@ -52,13 +69,20 @@ class Const: ENV_ENABLE = "1" ENV_DISABLE = "0" MAX_SEED_VALUE = 4294967295 # 2**32 - 1 - TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"] - LEVEL_LIST = ["L0", "L1", "L2", "mix"] STATISTICS = "statistics" TENSOR = "tensor" OVERFLOW_CHECK = "overflow_check" FREE_BENCHMARK = "free_benchmark" + RUN_UT = "run_ut" + GRAD_PROBE = "grad_probe" + TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE] + LEVEL_L0 = "L0" + LEVEL_L1 = "L1" + LEVEL_L2 = "L2" + LEVEL_MIX = "mix" + LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX] ATTR_NAME_PREFIX = "wrap_" + ATTR_NAME_PREFIX_LEN = len(ATTR_NAME_PREFIX) KERNEL_DUMP = "kernel_dump" DATA = "data" PT_FRAMEWORK = "pytorch" @@ -69,13 +93,17 @@ class Const: BOOL_TYPE = [bool, np.uint8] INT_TYPE = [np.int32, np.int64] NPU = 'NPU' + NPU_LOWERCASE = 'npu' + CPU_LOWERCASE = 'cpu' + CUDA_LOWERCASE = 'cuda' DISTRIBUTED = 'Distributed' - + INPLACE_LIST = [ "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", - "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single" + "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all", + "all_gather_into_tensor", "reduce_scatter_tensor" ] - + CONVERT = { "int32_to_int64": ["torch.int32", "torch.int64"], } @@ -84,6 +112,7 @@ class Const: "int32_to_int64": ["cross_entropy"] } + class CompareConst: """ Class for compare module const @@ -123,7 +152,12 @@ class CompareConst: NPU_MD5 = "NPU MD5" BENCH_MD5 = "BENCH MD5" RESULT = "Result" - + MAGNITUDE = 0.5 + OP_NAME = "op_name" + INPUT_STRUCT = "input_struct" + OUTPUT_STRUCT = "output_struct" + SUMMARY = "summary" + COMPARE_RESULT_HEADER = [ NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, @@ -160,6 +194,7 @@ class CompareConst: WARNING = 'Warning' ERROR = 'error' SKIP = 'SKIP' + N_A = 'N/A' BFLOAT16_MIN = -3.3895313892515355e+38 BFLOAT16_MAX = 3.3895313892515355e+38 BFLOAT16_EPS = 3.90625e-3 # 2 ** -8 @@ -167,6 +202,7 @@ class CompareConst: # accuracy standards COS_THRESHOLD = 0.99 MAX_ABS_ERR_THRESHOLD = 0.001 + MAX_RELATIVE_ERR_THRESHOLD = 0.001 COS_MAX_THRESHOLD = 0.9 MAX_ABS_ERR_MAX_THRESHOLD = 1 ACCURACY_CHECK_YES = "Yes" @@ -183,6 +219,10 @@ class CompareConst: RED = "FFFF0000" YELLOW = "FFFF00" BLUE = "0000FF" + + # run_ut const + MAX_TOKENS = 65536 + SPECIAL_SPARSE_MOED = 4 # highlight rules const OVERFLOW_LIST = ['nan\t', 'inf\t', '-inf\t', 'nan', 'inf', '-inf'] @@ -195,6 +235,20 @@ class CompareConst: MAX_RELATIVE_OUT_RED = 0.5 MAX_RELATIVE_OUT_YELLOW = 0.1 MAX_RELATIVE_IN_YELLOW = 0.01 + MS_GRAPH_BASE = { + NPU_NAME: None, BENCH_NAME: None, NPU_DTYPE: None, BENCH_DTYPE: None, NPU_SHAPE: None, BENCH_SHAPE: None, + NPU_MAX: None, NPU_MIN: None, NPU_MEAN: None, NPU_NORM: None, BENCH_MAX: None, BENCH_MIN: None, + BENCH_MEAN: None, BENCH_NORM: None, ACCURACY: '', ERROR_MESSAGE: '' + } + MS_GRAPH_NPY = { + COSINE: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None, + FIVE_THOUSANDTHS_ERR_RATIO: None + } + MS_GRAPH_STATISTIC = { + MAX_DIFF: None, MIN_DIFF: None, MEAN_DIFF: None, NORM_DIFF: None, MAX_RELATIVE_ERR: None, + MIN_RELATIVE_ERR: None, MEAN_RELATIVE_ERR: None, NORM_RELATIVE_ERR: None + } + class FileCheckConst: """ @@ -232,6 +286,7 @@ class FileCheckConst: YAML_SUFFIX: MAX_YAML_SIZE } + class OverflowConst: """ Class for Overflow @@ -239,3 +294,44 @@ class OverflowConst: OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE" OVERFLOW_ORIGINAL_MODE = 0 OVERFLOW_DEBUG_MODE = 1 + +class MsCompareConst: + # api_info field + MINT = "Mint" + MINT_FUNCTIONAL = "MintFunctional" + + TASK_FIELD = "task" + STATISTICS_TASK = "statistics" + TENSOR_TASK = "tensor" + DUMP_DATA_DIR_FIELD = "dump_data_dir" + DATA_FIELD = "data" + + #detail_csv + DETAIL_CSV_API_NAME = "API Name" + DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" + DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" + DETAIL_CSV_SHAPE = "Shape" + DETAIL_CSV_PASS_STATUS = "Status" + DETAIL_CSV_MESSAGE = "Message" + DETAIL_CSV_FILE_NAME = "accuracy_checking_details" + + #result_csv + RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" + RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" + RESULT_CSV_FILE_NAME = "accuracy_checking_result" + + EPSILON = 1e-8 + +class MsgConst: + """ + Class for log messages const + """ + CLEAR_SYMBOL = "\033[K" + LEVEL = ["INFO", "WARNING", "ERROR"] + SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"] + + +class GraphMode: + NPY_MODE = "NPY_MODE" + STATISTIC_MODE = "STATISTIC_MODE" + ERROR_MODE = "ERROR_MODE" diff --git a/debug/accuracy_tools/atat/core/common/exceptions.py b/debug/accuracy_tools/msprobe/core/common/exceptions.py similarity index 52% rename from debug/accuracy_tools/atat/core/common/exceptions.py rename to debug/accuracy_tools/msprobe/core/common/exceptions.py index f6b7c19ba32df57513562436c43518cf4a40a7cc..ea61f8cd58fe057ba6836dd1ed368d52adedeb18 100644 --- a/debug/accuracy_tools/atat/core/common/exceptions.py +++ b/debug/accuracy_tools/msprobe/core/common/exceptions.py @@ -1,19 +1,20 @@ class CodedException(Exception): def __init__(self, code, error_info=''): super().__init__() + self.code = code self.error_info = self.err_strs.get(code) + error_info def __str__(self): return self.error_info -class MsaccException(CodedException): +class MsprobeException(CodedException): INVALID_PARAM_ERROR = 0 OVERFLOW_NUMS_ERROR = 1 err_strs = { - INVALID_PARAM_ERROR: "[msacc] 无效参数: ", - OVERFLOW_NUMS_ERROR: "[msacc] 超过预设溢出次数 当前溢出次数:" + INVALID_PARAM_ERROR: "[msprobe] 无效参数: ", + OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:" } @@ -26,12 +27,12 @@ class FileCheckException(CodedException): FILE_TOO_LARGE_ERROR = 5 err_strs = { - SOFT_LINK_ERROR: "[msacc] 检测到软链接: ", - FILE_PERMISSION_ERROR: "[msacc] 文件权限错误: ", - INVALID_FILE_ERROR: "[msacc] 无效文件: ", - ILLEGAL_PATH_ERROR: "[msacc] 非法文件路径: ", - ILLEGAL_PARAM_ERROR: "[msacc] 非法打开方式: ", - FILE_TOO_LARGE_ERROR: "[msacc] 文件过大: " + SOFT_LINK_ERROR: "[msprobe] 检测到软链接: ", + FILE_PERMISSION_ERROR: "[msprobe] 文件权限错误: ", + INVALID_FILE_ERROR: "[msprobe] 无效文件: ", + ILLEGAL_PATH_ERROR: "[msprobe] 非法文件路径: ", + ILLEGAL_PARAM_ERROR: "[msprobe] 非法打开方式: ", + FILE_TOO_LARGE_ERROR: "[msprobe] 文件过大: " } @@ -39,8 +40,8 @@ class ParseJsonException(CodedException): UnexpectedNameStruct = 0 InvalidDumpJson = 1 err_strs = { - UnexpectedNameStruct: "[msacc] Unexpected name in json: ", - InvalidDumpJson: "[msacc] json格式不正确: ", + UnexpectedNameStruct: "[msprobe] Unexpected name in json: ", + InvalidDumpJson: "[msprobe] json格式不正确: ", } @@ -49,23 +50,23 @@ class ScopeException(CodedException): InvalidScope = 1 ArgConflict = 2 err_strs = { - InvalidApiStr: "[msacc] Invalid api_list: ", - InvalidScope: "[msacc] Invalid scope: ", - ArgConflict: "[msacc] Scope and api_list conflict: ", + InvalidApiStr: "[msprobe] Invalid api_list: ", + InvalidScope: "[msprobe] Invalid scope: ", + ArgConflict: "[msprobe] Scope and api_list conflict: ", } class RepairException(CodedException): InvalidRepairType = 0 err_strs = { - InvalidRepairType: "[msacc] Invalid repair_type: " + InvalidRepairType: "[msprobe] Invalid repair_type: " } class StepException(CodedException): InvalidPostProcess = 0 err_strs = { - InvalidPostProcess: "[msacc] 错误的step后处理配置: ", + InvalidPostProcess: "[msprobe] 错误的step后处理配置: ", } @@ -73,8 +74,8 @@ class FreeBenchmarkException(CodedException): UnsupportedType = 0 InvalidGrad = 1 err_strs = { - UnsupportedType: "[msacc] Free benchmark get unsupported type: ", - InvalidGrad: "[msacc] Free benchmark gradient invalid: ", + UnsupportedType: "[msprobe] Free benchmark get unsupported type: ", + InvalidGrad: "[msprobe] Free benchmark gradient invalid: ", } diff --git a/debug/accuracy_tools/msprobe/core/common/file_utils.py b/debug/accuracy_tools/msprobe/core/common/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0976323b5a760c1e6d250923ab7bfdbc166a0bef --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/common/file_utils.py @@ -0,0 +1,478 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import csv +import fcntl +import os +import json +import re +import shutil +import yaml +import numpy as np + +from msprobe.core.common.log import logger +from msprobe.core.common.exceptions import FileCheckException +from msprobe.core.common.const import FileCheckConst + + +class FileChecker: + """ + The class for check file. + + Attributes: + file_path: The file or dictionary path to be verified. + path_type: file or dictionary + ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability + file_type(str): The correct file type for file + """ + + def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): + self.file_path = file_path + self.path_type = self._check_path_type(path_type) + self.ability = ability + self.file_type = file_type + self.is_script = is_script + + @staticmethod + def _check_path_type(path_type): + if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: + logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') + raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) + return path_type + + def common_check(self): + """ + 功能:用户校验基本文件权限:软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符 + 注意:文件后缀的合法性,非通用操作,可使用其他独立接口实现 + """ + check_path_exists(self.file_path) + check_link(self.file_path) + self.file_path = os.path.realpath(self.file_path) + check_path_length(self.file_path) + check_path_type(self.file_path, self.path_type) + self.check_path_ability() + if self.is_script: + check_path_owner_consistent(self.file_path) + check_path_pattern_vaild(self.file_path) + check_common_file_size(self.file_path) + check_file_suffix(self.file_path, self.file_type) + return self.file_path + + def check_path_ability(self): + if self.ability == FileCheckConst.WRITE_ABLE: + check_path_writability(self.file_path) + if self.ability == FileCheckConst.READ_ABLE: + check_path_readability(self.file_path) + if self.ability == FileCheckConst.READ_WRITE_ABLE: + check_path_readability(self.file_path) + check_path_writability(self.file_path) + + +class FileOpen: + """ + The class for open file by a safe way. + + Attributes: + file_path: The file or dictionary path to be opened. + mode(str): The file open mode + """ + SUPPORT_READ_MODE = ["r", "rb"] + SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"] + SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"] + + def __init__(self, file_path, mode, encoding='utf-8'): + self.file_path = file_path + self.mode = mode + self.encoding = encoding + self._handle = None + + def __enter__(self): + self.check_file_path() + binary_mode = "b" + if binary_mode not in self.mode: + self._handle = open(self.file_path, self.mode, encoding=self.encoding) + else: + self._handle = open(self.file_path, self.mode) + return self._handle + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._handle: + self._handle.close() + + def check_file_path(self): + support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE + if self.mode not in support_mode: + logger.error("File open not support %s mode" % self.mode) + raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) + check_link(self.file_path) + self.file_path = os.path.realpath(self.file_path) + check_path_length(self.file_path) + self.check_ability_and_owner() + check_path_pattern_vaild(self.file_path) + if os.path.exists(self.file_path): + check_common_file_size(self.file_path) + + def check_ability_and_owner(self): + if self.mode in self.SUPPORT_READ_MODE: + check_path_exists(self.file_path) + check_path_readability(self.file_path) + check_path_owner_consistent(self.file_path) + if self.mode in self.SUPPORT_WRITE_MODE and os.path.exists(self.file_path): + check_path_writability(self.file_path) + check_path_owner_consistent(self.file_path) + if self.mode in self.SUPPORT_READ_WRITE_MODE and os.path.exists(self.file_path): + check_path_readability(self.file_path) + check_path_writability(self.file_path) + check_path_owner_consistent(self.file_path) + + +def check_link(path): + abs_path = os.path.abspath(path) + if os.path.islink(abs_path): + logger.error('The file path {} is a soft link.'.format(path)) + raise FileCheckException(FileCheckException.SOFT_LINK_ERROR) + + +def check_path_length(path, name_length=None): + file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH + if len(path) > FileCheckConst.DIRECTORY_LENGTH or \ + len(os.path.basename(path)) > file_max_name_length: + logger.error('The file path length exceeds limit.') + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) + + +def check_path_exists(path): + if not os.path.exists(path): + logger.error('The file path %s does not exist.' % path) + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) + + +def check_path_readability(path): + if not os.access(path, os.R_OK): + logger.error('The file path %s is not readable.' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + + +def check_path_writability(path): + if not os.access(path, os.W_OK): + logger.error('The file path %s is not writable.' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + + +def check_path_executable(path): + if not os.access(path, os.X_OK): + logger.error('The file path %s is not executable.' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + + +def check_other_user_writable(path): + st = os.stat(path) + if st.st_mode & 0o002: + logger.error('The file path %s may be insecure because other users have write permissions. ' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + + +def check_path_owner_consistent(path): + file_owner = os.stat(path).st_uid + if file_owner != os.getuid(): + logger.error('The file path %s may be insecure because is does not belong to you.' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + + +def check_path_pattern_vaild(path): + if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): + logger.error('The file path %s contains special characters.' % (path)) + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) + + +def check_file_size(file_path, max_size): + try: + file_size = os.path.getsize(file_path) + except OSError as os_error: + logger.error(f'Failed to open "{file_path}". {str(os_error)}') + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from os_error + if file_size >= max_size: + logger.error(f'The size ({file_size}) of {file_path} exceeds ({max_size}) bytes, tools not support.') + raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR) + + +def check_common_file_size(file_path): + if os.path.isfile(file_path): + for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items(): + if file_path.endswith(suffix): + check_file_size(file_path, max_size) + break + + +def check_file_suffix(file_path, file_suffix): + if file_suffix: + if not file_path.endswith(file_suffix): + logger.error(f"The {file_path} should be a {file_suffix} file!") + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + + +def check_path_type(file_path, file_type): + if file_type == FileCheckConst.FILE: + if not os.path.isfile(file_path): + logger.error(f"The {file_path} should be a file!") + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + if file_type == FileCheckConst.DIR: + if not os.path.isdir(file_path): + logger.error(f"The {file_path} should be a dictionary!") + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + + +def make_dir(dir_path): + dir_path = os.path.realpath(dir_path) + check_path_before_create(dir_path) + if os.path.isdir(dir_path): + return + try: + os.mkdir(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) + except OSError as ex: + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, + f"Failed to create {dir_path}. " + f"Please check the path permission or disk space. {str(ex)}") from ex + file_check = FileChecker(dir_path, FileCheckConst.DIR) + file_check.common_check() + + +def create_directory(dir_path): + """ + Function Description: + creating a safe directory with specified permissions + Parameter: + dir_path: directory path + Exception Description: + when invalid data throw exception + """ + dir_path = os.path.realpath(dir_path) + check_path_before_create(dir_path) + parent_dir = os.path.dirname(dir_path) + if not os.path.isdir(parent_dir): + create_directory(parent_dir) + make_dir(dir_path) + + +def check_path_before_create(path): + if path_len_exceeds_limit(path): + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.') + + if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)): + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, + 'The file path {} contains special characters.'.format(path)) + + +def check_file_or_directory_path(path, isdir=False): + """ + Function Description: + check whether the path is valid + Parameter: + path: the path to check + isdir: the path is dir or file + Exception Description: + when invalid data throw exception + """ + if isdir: + path_checker = FileChecker(path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE) + else: + path_checker = FileChecker(path, FileCheckConst.FILE, FileCheckConst.READ_ABLE) + path_checker.common_check() + + +def change_mode(path, mode): + if not os.path.exists(path) or os.path.islink(path): + return + try: + os.chmod(path, mode) + except PermissionError as ex: + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR, + 'Failed to change {} authority. {}'.format(path, str(ex))) from ex + + +def path_len_exceeds_limit(file_path): + return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ + len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH + + +def check_file_type(path): + """ + Function Description: + determine if it is a file or a directory + Parameter: + path: path + Exception Description: + when neither a file nor a directory throw exception + """ + if os.path.isdir(path): + return FileCheckConst.DIR + elif os.path.isfile(path): + return FileCheckConst.FILE + else: + logger.error('Neither a file nor a directory.') + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + + +def load_yaml(yaml_path): + path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX) + checked_path = path_checker.common_check() + try: + with FileOpen(checked_path, "r") as f: + yaml_data = yaml.safe_load(f) + except Exception as e: + logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.") + raise RuntimeError(f"Load yaml file {checked_path} failed.") from e + return yaml_data + + +def load_npy(filepath, enable_pickle=False): + check_file_or_directory_path(filepath) + try: + npy = np.load(filepath, allow_pickle=enable_pickle) + except Exception as e: + logger.error(f"The numpy file failed to load. Please check the path: {filepath}.") + raise RuntimeError(f"Load numpy file {filepath} failed.") from e + return npy + + +def load_json(json_path): + try: + with FileOpen(json_path, "r") as f: + fcntl.flock(f, fcntl.LOCK_EX) + data = json.load(f) + fcntl.flock(f, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'load json file "{os.path.basename(json_path)}" failed.') + raise RuntimeError(f"Load json file {json_path} failed.") from e + return data + + +def save_json(json_path, data, indent=None): + json_path = os.path.realpath(json_path) + check_path_before_create(json_path) + try: + with FileOpen(json_path, 'w') as f: + fcntl.flock(f, fcntl.LOCK_EX) + json.dump(data, f, indent=indent) + fcntl.flock(f, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'Save json file "{os.path.basename(json_path)}" failed.') + raise RuntimeError(f"Save json file {json_path} failed.") from e + change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY) + + +def move_file(src_path, dst_path): + check_file_or_directory_path(src_path) + check_path_before_create(dst_path) + try: + shutil.move(src_path, dst_path) + except Exception as e: + logger.error(f"move file {src_path} to {dst_path} failed") + raise RuntimeError(f"move file {src_path} to {dst_path} failed") from e + change_mode(dst_path, FileCheckConst.DATA_FILE_AUTHORITY) + + +def save_npy(data, filepath): + filepath = os.path.realpath(filepath) + check_path_before_create(filepath) + try: + np.save(filepath, data) + except Exception as e: + logger.error(f"The numpy file failed to save. Please check the path: {filepath}.") + raise RuntimeError(f"Save numpy file {filepath} failed.") from e + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + +def save_npy_to_txt(self, data, dst_file='', align=0): + if os.path.exists(dst_file): + self.log.info("Dst file %s exists, will not save new one.", dst_file) + return + shape = data.shape + data = data.flatten() + if align == 0: + align = 1 if len(shape) == 0 else shape[-1] + elif data.size % align != 0: + pad_array = np.zeros((align - data.size % align,)) + data = np.append(data, pad_array) + check_path_before_create(dst_file) + try: + np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g') + except Exception as e: + self.log.error("An unexpected error occurred: %s when savetxt to %s" % (str(e)), dst_file) + change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY) + + +def save_workbook(workbook, file_path): + """ + 保存工作簿到指定的文件路径 + workbook: 要保存的工作簿对象 + file_path: 文件保存路径 + """ + file_path = os.path.realpath(file_path) + check_path_before_create(file_path) + try: + workbook.save(file_path) + except Exception as e: + logger.error(f'Save result file "{os.path.basename(file_path)}" failed') + raise RuntimeError(f"Save result file {file_path} failed.") from e + change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) + + +def write_csv(data, filepath, mode="a+"): + file_path = os.path.realpath(filepath) + check_path_before_create(filepath) + try: + with FileOpen(filepath, mode, encoding='utf-8-sig') as f: + writer = csv.writer(f) + writer.writerows(data) + except Exception as e: + logger.error(f'Save csv file "{os.path.basename(file_path)}" failed') + raise RuntimeError(f"Save csv file {file_path} failed.") from e + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + +def remove_path(path): + if not os.path.exists(path): + return + try: + if os.path.islink(path) or os.path.isfile(path): + os.remove(path) + else: + shutil.rmtree(path) + except PermissionError as err: + logger.error("Failed to delete {}. Please check the permission.".format(path)) + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) from err + except Exception as e: + logger.error("Failed to delete {}. Please check.".format(path)) + raise RuntimeError(f"Delete {path} failed.") from e + + +def get_json_contents(file_path): + ops = get_file_content_bytes(file_path) + try: + json_obj = json.loads(ops) + except ValueError as error: + logger.error('Failed to load json.') + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from error + if not isinstance(json_obj, dict): + logger.error('Json file content is not a dictionary!') + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + return json_obj + + +def get_file_content_bytes(file): + with FileOpen(file, 'rb') as file_handle: + return file_handle.read() diff --git a/debug/accuracy_tools/atat/core/common/log.py b/debug/accuracy_tools/msprobe/core/common/log.py similarity index 100% rename from debug/accuracy_tools/atat/core/common/log.py rename to debug/accuracy_tools/msprobe/core/common/log.py diff --git a/debug/accuracy_tools/atat/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py similarity index 72% rename from debug/accuracy_tools/atat/core/common/utils.py rename to debug/accuracy_tools/msprobe/core/common/utils.py index 088530f3c5c88e4e97cd1eda470b02a6a2176fdf..6795bd83f350e2084dc94991cd48c481f38586aa 100644 --- a/debug/accuracy_tools/atat/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -15,20 +15,21 @@ # limitations under the License. """ import collections +import fcntl import os import re import shutil -import stat import subprocess import time import json +import csv from datetime import datetime, timezone -from pathlib import Path +import yaml import numpy as np -from atat.core.common.file_check import FileOpen, FileChecker -from atat.core.common.const import Const, FileCheckConst, CompareConst, OverflowConst -from atat.core.common.log import logger +from msprobe.core.common.file_utils import FileOpen, FileChecker, change_mode +from msprobe.core.common.const import Const, FileCheckConst, CompareConst +from msprobe.core.common.log import logger device = collections.namedtuple('device', ['type', 'index']) @@ -60,6 +61,8 @@ class CompareException(Exception): OVER_SIZE_FILE_ERROR = 18 INVALID_SUMMARY_MODE = 19 INVALID_TASK_ERROR = 20 + DETACH_ERROR = 21 + def __init__(self, code, error_info: str = ""): super(CompareException, self).__init__() @@ -70,21 +73,26 @@ class CompareException(Exception): return self.error_info -class DumpException(CompareException): - pass +class PathAddException(Exception): + """ + Class for Path Add Exception + """ + INVALID_FILE_ERROR = 0 + PERMISSION_DENIED_ERROR = 1 + UNEXPECTED_ERROR = 2 + PACKAGE_VERSION_CHECK = 3 + def __init__(self, code, error_info: str = ""): + super(PathAddException, self).__init__() + self.code = code + self.error_info = error_info -def make_dump_path_if_not_exists(dump_path): - if not os.path.exists(dump_path): - try: - Path(dump_path).mkdir(mode=0o750, exist_ok=True, parents=True) - except OSError as ex: - logger.error( - 'Failed to create {}.Please check the path permission or disk space .{}'.format(dump_path, str(ex))) - raise CompareException(CompareException.INVALID_PATH_ERROR) from ex - else: - if not os.path.isdir(dump_path): - logger.error('{} already exists and is not a directory.'.format(dump_path)) + def __str__(self): + return self.error_info + + +class DumpException(CompareException): + pass def check_mode_valid(mode, scope=None, api_list=None): @@ -148,21 +156,24 @@ def check_summary_only_valid(summary_only): return summary_only -def check_compare_param(input_parma, output_path, stack_mode=False, summary_compare=False, md5_compare=False): - if not (isinstance(input_parma, dict) and isinstance(output_path, str)): +def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False): + if not (isinstance(input_param, dict) and isinstance(output_path, str)): logger.error("Invalid input parameters") raise CompareException(CompareException.INVALID_PARAM_ERROR) - check_file_or_directory_path(input_parma.get("npu_json_path"), False) - check_file_or_directory_path(input_parma.get("bench_json_path"), False) - check_file_or_directory_path(input_parma.get("stack_json_path"), False) + + check_file_or_directory_path(input_param.get("npu_json_path"), False) + check_file_or_directory_path(input_param.get("bench_json_path"), False) + check_file_or_directory_path(input_param.get("stack_json_path"), False) if not summary_compare and not md5_compare: - check_file_or_directory_path(input_parma.get("npu_dump_data_dir"), True) - check_file_or_directory_path(input_parma.get("bench_dump_data_dir"), True) + check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True) + check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True) check_file_or_directory_path(output_path, True) - with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \ - FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \ - FileOpen(input_parma.get("stack_json_path"), "r") as stack_json: - check_json_file(input_parma, npu_json, bench_json, stack_json) + + with FileOpen(input_param.get("npu_json_path"), "r") as npu_json, \ + FileOpen(input_param.get("bench_json_path"), "r") as bench_json, \ + FileOpen(input_param.get("stack_json_path"), "r") as stack_json: + check_json_file(input_param, npu_json, bench_json, stack_json) + def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False): @@ -257,6 +268,17 @@ def remove_path(path): raise CompareException(CompareException.INVALID_PATH_ERROR) from err +def move_file(src_path, dst_path): + check_file_or_directory_path(src_path) + check_path_before_create(dst_path) + try: + shutil.move(src_path, dst_path) + except Exception as e: + logger.error(f"move file {src_path} to {dst_path} failed") + raise RuntimeError(f"move file {src_path} to {dst_path} failed") from e + change_mode(dst_path, FileCheckConst.DATA_FILE_AUTHORITY) + + def get_dump_data_path(dump_dir): """ Function Description: @@ -279,24 +301,6 @@ def get_dump_data_path(dump_dir): return dump_data_path, file_is_exist -def create_directory(dir_path): - """ - Function Description: - creating a directory with specified permissions - Parameter: - dir_path: directory path - Exception Description: - when invalid data throw exception - """ - if not os.path.exists(dir_path): - try: - os.makedirs(dir_path, mode=0o700) - except OSError as ex: - logger.error( - 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) - raise CompareException(CompareException.INVALID_PATH_ERROR) from ex - - def execute_command(cmd): """ Function Description: @@ -318,15 +322,6 @@ def execute_command(cmd): raise CompareException(CompareException.INVALID_DATA_ERROR) -def save_numpy_data(file_path, data): - """ - save_numpy_data - """ - if not os.path.exists(os.path.dirname(file_path)): - os.makedirs(os.path.dirname(file_path)) - np.save(file_path, data) - - def parse_value_by_comma(value): """ parse value by comma, like '1,2,4,8' @@ -472,14 +467,14 @@ def md5_find(data): def task_dumppath_get(input_param): - npu_json_path = input_param.get("npu_json_path", None) - bench_json_path = input_param.get("bench_json_path", None) - if not npu_json_path or not bench_json_path: + npu_path = input_param.get("npu_json_path", None) + bench_path = input_param.get("bench_json_path", None) + if not npu_path or not bench_path: logger.error(f"Please check the json path is valid.") raise CompareException(CompareException.INVALID_PATH_ERROR) - with FileOpen(npu_json_path, 'r') as npu_f: + with FileOpen(npu_path, 'r') as npu_f: npu_json_data = json.load(npu_f) - with FileOpen(bench_json_path, 'r') as bench_f: + with FileOpen(bench_path, 'r') as bench_f: bench_json_data = json.load(bench_f) if npu_json_data['task'] != bench_json_data['task']: logger.error(f"Please check the dump task is consistent.") @@ -496,8 +491,8 @@ def task_dumppath_get(input_param): else: logger.error(f"Compare is not required for overflow_check or free_benchmark.") raise CompareException(CompareException.INVALID_TASK_ERROR) - input_param['npu_dump_data_dir'] = npu_json_data['dump_data_dir'] - input_param['bench_dump_data_dir'] = bench_json_data['dump_data_dir'] + input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA) + input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA) return summary_compare, md5_compare @@ -514,3 +509,122 @@ def get_header_index(header_name, summary_compare=False): def convert_tuple(data): return data if isinstance(data, tuple) else (data, ) + + +def write_csv(data, filepath, mode="a+"): + exist = os.path.exists(filepath) + with FileOpen(filepath, mode, encoding='utf-8-sig') as f: + writer = csv.writer(f) + writer.writerows(data) + if not exist: + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + +def load_npy(filepath, enable_pickle=False): + check_file_or_directory_path(filepath) + try: + npy = np.load(filepath, allow_pickle=enable_pickle) + except Exception as e: + logger.error(f"The numpy file failed to load. Please check the path: {filepath}.") + raise RuntimeError(f"Load numpy file {filepath} failed.") from e + return npy + + +def save_npy(data, filepath): + filepath = os.path.realpath(filepath) + check_path_before_create(filepath) + try: + np.save(filepath, data) + except Exception as e: + logger.error(f"The numpy file failed to save. Please check the path: {filepath}.") + raise RuntimeError(f"Save numpy file {filepath} failed.") from e + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + +def save_npy_to_txt(self, data, dst_file='', align=0): + if os.path.exists(dst_file): + self.log.info("Dst file %s exists, will not save new one.", dst_file) + return + shape = data.shape + data = data.flatten() + if align == 0: + align = 1 if len(shape) == 0 else shape[-1] + elif data.size % align != 0: + pad_array = np.zeros((align - data.size % align,)) + data = np.append(data, pad_array) + check_path_before_create(dst_file) + try: + np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g') + except Exception as e: + self.log.error("An unexpected error occurred: %s when savetxt to %s" % (str(e)), dst_file) + change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY) + +def get_json_contents(file_path): + ops = get_file_content_bytes(file_path) + try: + json_obj = json.loads(ops) + except ValueError as error: + logger.error('Failed to load json.') + raise CompareException(CompareException.INVALID_FILE_ERROR) from error + if not isinstance(json_obj, dict): + logger.error('Json file content is not a dictionary!') + raise CompareException(CompareException.INVALID_FILE_ERROR) + return json_obj + + +def get_file_content_bytes(file): + with FileOpen(file, 'rb') as file_handle: + return file_handle.read() + + +def load_yaml(yaml_path): + path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX) + checked_path = path_checker.common_check() + try: + with FileOpen(checked_path, "r") as f: + yaml_data = yaml.safe_load(f) + except Exception as e: + logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.") + raise RuntimeError(f"Load yaml file {checked_path} failed.") from e + return yaml_data + + +def save_workbook(workbook, file_path): + """ + 保存工作簿到指定的文件路径 + workbook: 要保存的工作簿对象 + file_path: 文件保存路径 + """ + file_path = os.path.realpath(file_path) + check_path_before_create(file_path) + try: + workbook.save(file_path) + except Exception as e: + logger.error(f'Save result file "{os.path.basename(file_path)}" failed') + raise CompareException(CompareException.WRITE_FILE_ERROR) from e + change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) + + +def load_json(json_path): + try: + with FileOpen(json_path, "r") as f: + fcntl.flock(f, fcntl.LOCK_EX) + data = json.load(f) + fcntl.flock(f, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'load json file "{os.path.basename(json_path)}" failed.') + raise DumpException(DumpException.WRITE_FILE_ERROR) from e + return data + + +def save_json(json_path, data, indent=None): + json_path = os.path.realpath(json_path) + check_path_before_create(json_path) + try: + with FileOpen(json_path, 'w') as f: + fcntl.flock(f, fcntl.LOCK_EX) + json.dump(data, f, indent=indent) + fcntl.flock(f, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'Save json file "{os.path.basename(json_path)}" failed.') + raise DumpException(DumpException.WRITE_FILE_ERROR) from e + change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY) diff --git a/debug/accuracy_tools/atat/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py similarity index 50% rename from debug/accuracy_tools/atat/core/common_config.py rename to debug/accuracy_tools/msprobe/core/common_config.py index e256372ca877be1ca5474dd87c00decdf9d3c1c1..3c92d6e766f568ff012cc3358ec4d0d6064aa1d2 100644 --- a/debug/accuracy_tools/atat/core/common_config.py +++ b/debug/accuracy_tools/msprobe/core/common_config.py @@ -1,6 +1,6 @@ -from atat.core.common.const import Const -from atat.core.common.log import logger -from atat.core.common.exceptions import MsaccException +from msprobe.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.core.common.exceptions import MsprobeException class CommonConfig: @@ -14,28 +14,35 @@ class CommonConfig: self.acl_config = json_config.get('acl_config') self.is_deterministic = json_config.get('is_deterministic', False) self.enable_dataloader = json_config.get('enable_dataloader', False) + self.enable_step_auto_dump = json_config.get('enable_step_auto_dump', False) self._check_config() def _check_config(self): if self.task and self.task not in Const.TASK_LIST: - logger.error_log_with_exp( - "task is invalid, it should be one of {}".format(Const.TASK_LIST), MsaccException(MsaccException.INVALID_PARAM_ERROR)) + logger.error_log_with_exp("task is invalid, it should be one of {}".format(Const.TASK_LIST), + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) if self.rank is not None and not isinstance(self.rank, list): - logger.error_log_with_exp("rank is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR)) + logger.error_log_with_exp("rank is invalid, it should be a list", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) if self.step is not None and not isinstance(self.step, list): - logger.error_log_with_exp("step is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR)) + logger.error_log_with_exp("step is invalid, it should be a list", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) if self.level and self.level not in Const.LEVEL_LIST: - logger.error_log_with_exp( - "level is invalid, it should be one of {}".format(Const.LEVEL_LIST), MsaccException(MsaccException.INVALID_PARAM_ERROR)) + logger.error_log_with_exp("level is invalid, it should be one of {}".format(Const.LEVEL_LIST), + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) if self.seed is not None and not isinstance(self.seed, int): - logger.error_log_with_exp("seed is invalid, it should be an integer", MsaccException(MsaccException.INVALID_PARAM_ERROR)) + logger.error_log_with_exp("seed is invalid, it should be an integer", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) if not isinstance(self.is_deterministic, bool): - logger.error_log_with_exp( - "is_deterministic is invalid, it should be a boolean", MsaccException(MsaccException.INVALID_PARAM_ERROR)) + logger.error_log_with_exp("is_deterministic is invalid, it should be a boolean", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) if not isinstance(self.enable_dataloader, bool): - logger.error_log_with_exp( - "enable_dataloader is invalid, it should be a boolean", MsaccException(MsaccException.INVALID_PARAM_ERROR)) - + logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) + if not isinstance(self.enable_step_auto_dump, bool): + logger.error_log_with_exp("enable_step_auto_dump is invalid, it should be a boolean", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) + class BaseConfig: def __init__(self, json_config): @@ -44,15 +51,17 @@ class BaseConfig: self.data_mode = json_config.get('data_mode') self.backward_input = json_config.get("backward_input") self.file_format = json_config.get("file_format") - self.summary_mode = json_config.get("summary_mode") - self.overflow_num = json_config.get("overflow_num") + self.summary_mode = json_config.get("summary_mode") + self.overflow_nums = json_config.get("overflow_nums") self.check_mode = json_config.get("check_mode") def check_config(self): if self.scope is not None and not isinstance(self.scope, list): - logger.error_log_with_exp("scope is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR)) + logger.error_log_with_exp("scope is invalid, it should be a list", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) if self.list is not None and not isinstance(self.list, list): - logger.error_log_with_exp("list is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR)) + logger.error_log_with_exp("list is invalid, it should be a list", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) if self.data_mode is not None and not isinstance(self.data_mode, list): - logger.error_log_with_exp("data_mode is invalid, it should be a list", MsaccException(MsaccException.INVALID_PARAM_ERROR)) - + logger.error_log_with_exp("data_mode is invalid, it should be a list", + MsprobeException(MsprobeException.INVALID_PARAM_ERROR)) diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9cf02b953f6898e753139b3e0801f4dcf3799db1 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -0,0 +1,430 @@ + +import os +import re +import numpy as np +from msprobe.core.common.const import Const, CompareConst +from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger +from msprobe.core.common.file_utils import check_file_or_directory_path + + +def extract_json(dirname, stack_json=False): + json_path = '' + for fname in os.listdir(dirname): + if fname == "construct.json": + continue + full_path = os.path.join(dirname, fname) + if full_path.endswith('.json'): + json_path = full_path + if not stack_json and 'stack' not in json_path: + break + if stack_json and 'stack' in json_path: + break + + # Provide robustness on invalid directory inputs + if not json_path: + logger.error(f'No file is found in dump dir {dirname}. ') + raise CompareException(CompareException.NO_DUMP_FILE_ERROR) + return json_path + + +def check_and_return_dir_contents(dump_dir, prefix): + """ + check the given dump dir and validate files in dump dir by using the given prefix patterns to build a + pattern: ^{prefix}(?:0|[0-9][1-9]*)?$ + + Args: + dump_dir (str): dump dir + prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only + + Returns: + content [list]: dir contents + Raises: + CompareException: invalid path + ValueError: prefix not match the patterns + + """ + check_regex_prefix_format_valid(prefix) + check_file_or_directory_path(dump_dir, True) + contents = os.listdir(dump_dir) + pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$') + for name in contents: + if not pattern.match(name): + logger.error( + f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump " + f"output. Please check and delete irrelevant files in {dump_dir} and try again." + ) + raise CompareException(CompareException.INVALID_PATH_ERROR) + return contents + + +def rename_api(npu_name, process): + npu_split = npu_name.split(process) + torch_func_index, in_out = npu_split[0], npu_split[1] + torch_func_split = torch_func_index.rsplit(Const.SEP, 2) + torch_func = str(torch_func_split[0]) + str(in_out) + return torch_func + + +def read_op(op_data, op_name): + op_parsed_list = Const.DEFAULT_LIST + if Const.FORWARD in op_name: + if Const.INPUT_ARGS in op_data: + input_item = op_data[Const.INPUT_ARGS] + input_parsed_list = op_item_parse(input_item, op_name + '.input', None) + op_parsed_list = input_parsed_list.copy() + input_parsed_list.clear() + if Const.INPUT_KWARGS in op_data: + kwargs_item = op_data[Const.INPUT_KWARGS] + if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list): + kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '.input', None) + op_parsed_list += kwarg_parsed_list + kwarg_parsed_list.clear() + elif kwargs_item: + for kwarg in kwargs_item: + kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None) + op_parsed_list += kwarg_parsed_list + kwarg_parsed_list.clear() + if Const.OUTPUT in op_data: + output_item = op_data[Const.OUTPUT] + output_parsed_list = op_item_parse(output_item, op_name + '.output', None) + op_parsed_list += output_parsed_list + output_parsed_list.clear() + if Const.BACKWARD in op_name: + if Const.INPUT in op_data: + input_item = op_data[Const.INPUT] + input_parsed_list = op_item_parse(input_item, op_name + '.input', None) + op_parsed_list = input_parsed_list.copy() + input_parsed_list.clear() + if Const.OUTPUT in op_data: + output_item = op_data[Const.OUTPUT] + output_parsed_list = op_item_parse(output_item, op_name + '.output', None) + op_parsed_list += output_parsed_list + output_parsed_list.clear() + return op_parsed_list + + +def op_item_parse(item, op_name, index, item_list=None, top_bool=True): + if item_list is None: + item_list = [] + if item is None or (isinstance(item, dict) and not item): + if not top_bool: + tmp = {'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, + 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'} + else: + tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None, + 'shape': None, 'md5': None, 'data_name': '-1'} + item_list.append(tmp) + return item_list + if index is None: + if isinstance(item, dict): + full_op_name = op_name + '.0' + else: + full_op_name = op_name + else: + full_op_name = op_name + Const.SEP + str(index) + if isinstance(item, dict): + if 'type' not in item: + for kwarg in item: + kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None) + item_list += kwarg_parsed_list + kwarg_parsed_list.clear() + elif 'dtype' in item: + parsed_item = item + parsed_item['full_op_name'] = full_op_name + item_list.append(parsed_item) + elif 'type' in item: + parsed_item = {} + if item['type'] == 'torch.Size': + parsed_item['full_op_name'] = full_op_name + parsed_item['dtype'] = 'torch.Size' + parsed_item['shape'] = str(item['value']) + parsed_item['md5'] = None + parsed_item['Max'] = None + parsed_item['Min'] = None + parsed_item['Mean'] = None + parsed_item['Norm'] = None + parsed_item['data_name'] = '-1' + item_list.append(parsed_item) + elif item['type'] == 'slice': + parsed_item['full_op_name'] = full_op_name + parsed_item['dtype'] = 'slice' + parsed_item['shape'] = str(np.shape(np.array(item['value']))) + parsed_item['md5'] = None + parsed_item['Max'] = None + parsed_item['Min'] = None + parsed_item['Mean'] = None + parsed_item['Norm'] = None + parsed_item['data_name'] = '-1' + item_list.append(parsed_item) + else: + parsed_item['full_op_name'] = full_op_name + parsed_item['dtype'] = str(type(item['value'])) + parsed_item['shape'] = '[]' + parsed_item['md5'] = None + parsed_item['Max'] = item['value'] + parsed_item['Min'] = item['value'] + parsed_item['Mean'] = item['value'] + parsed_item['Norm'] = item['value'] + parsed_item['data_name'] = '-1' + item_list.append(parsed_item) + else: + resolve_api_special_parameters(item, full_op_name, item_list) + else: + for j, item_spec in enumerate(item): + op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False) + return item_list + + +def resolve_api_special_parameters(data_dict, full_op_name, item_list): + """ + Function Description: + 解析下面格式的数据, 是api参数的一种特殊格式 + { + "last_hidden_state": { + "type": "torch.Tensor", + "dtype": "torch.bfloat16", + ... + }, + "loss": { + "type": "torch.Tensor", + "dtype": "torch.float32", + ... + } + } + Parameter: + data_dict: 字典格式的数据 + full_op_name: 参数的全名字符串 + item_list: 参数信息集合 + """ + for key, value in data_dict.items(): + if isinstance(value, dict): + parsed_item = value + parts = full_op_name.split(Const.SEP) + parts.insert(-1, key) + full_op_name_new = ".".join(parts) + parsed_item['full_op_name'] = full_op_name_new + item_list.append(parsed_item) + + +def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False): + def get_accuracy_core(n_start, n_len, b_start, b_len, key): + min_len = min(n_len, b_len) + npu_stack_info = n_dict.get("stack_info", None) + bench_stack_info = b_dict.get("stack_info", None) + has_stack = npu_stack_info and bench_stack_info + + all_mode_bool = not (summary_compare or md5_compare) + if all_mode_bool: + npu_data_name = n_dict.get("data_name", None) + bench_data_name = b_dict.get("data_name", None) + + for index in range(min_len): + + n_name = n_dict['op_name'][n_start + index] + b_name = b_dict['op_name'][b_start + index] + n_struct = n_dict[key][index] + b_struct = b_dict[key][index] + err_msg = "" + if md5_compare: + result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], + n_struct[2], b_struct[2], + CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF] + if has_stack and index == 0 and key == "input_struct": + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + result.append(result_item) + continue + + if summary_compare: + result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], + " ", " ", " ", " ", " ", " ", " ", " "] + else: + result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], + " ", " ", " ", " ", " "] + + npu_summary_data = n_dict.get("summary")[n_start + index] + result_item.extend(npu_summary_data) + bench_summary_data = b_dict.get("summary")[b_start + index] + result_item.extend(bench_summary_data) + + if summary_compare: + start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF) + warning_flag = False + for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)): + if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)): + diff = npu_val - bench_val + if bench_val != 0: + relative = str(abs((diff / bench_val) * 100)) + '%' + else: + relative = "N/A" + result_item[start_idx + i] = diff + result_item[start_idx + i + 4] = relative + magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10) + if magnitude_diff > 0.5: + warning_flag = True + else: + result_item[start_idx + i] = CompareConst.NONE + accuracy_check = CompareConst.WARNING if warning_flag else "" + err_msg += "Need double check api accuracy." if warning_flag else "" + for i in range(start_idx, len(result_item)): + if str(result_item[i]) in ('inf', '-inf', 'nan'): + result_item[i] = f'{result_item[i]}\t' + + result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES) + result_item.append(err_msg) + if has_stack and index == 0 and key == "input_struct": + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + if all_mode_bool: + result_item.append(npu_data_name[n_start + index]) + + result.append(result_item) + + if n_len > b_len: + for index in range(b_len, n_len): + n_name = n_dict['op_name'][n_start + index] + n_struct = n_dict[key][index] + if md5_compare: + result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, + n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN] + result.append(result_item) + continue + result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, + n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "] + summary_data = n_dict.get("summary")[n_start + index] + result_item.extend(summary_data) + summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))] + result_item.extend(summary_data) + + err_msg = "" + result_item.append(CompareConst.ACCURACY_CHECK_YES) + result_item.append(err_msg) + + if has_stack and index == 0 and key == "input_struct": + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + if all_mode_bool: + result_item.append(npu_data_name[n_start + index]) + + result.append(result_item) + + n_num = len(n_dict['op_name']) + b_num = len(b_dict['op_name']) + n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name]) + b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name]) + n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name]) + b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name]) + n_num_output = n_num - n_num_input - n_num_kwarg + b_num_output = b_num - b_num_input - b_num_kwarg + get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct') + get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct") + get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct') + + +def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare): + index_out = 0 + npu_stack_info = n_dict.get("stack_info", None) + bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A + err_msg = CompareConst.NO_BENCH + accuracy_check_res = CompareConst.N_A + for index, n_name in enumerate(n_dict["op_name"]): + if n_name.find("input") != -1: + n_struct = n_dict["input_struct"][index] + else: + n_struct = n_dict["output_struct"][index_out] + index_out += 1 + + result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape] + if md5_compare: + result_item.extend([CompareConst.N_A] * 3) + if npu_stack_info and index == 0: + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + result.append(result_item) + continue + if summary_compare: + result_item.extend([CompareConst.N_A] * 8) + else: + result_item.extend([CompareConst.N_A] * 5) + npu_summary_data = n_dict.get("summary")[index] + result_item.extend(npu_summary_data) + bench_summary_data = [CompareConst.N_A] * 4 + result_item.extend(bench_summary_data) + result_item.append(accuracy_check_res) + result_item.append(err_msg) + if npu_stack_info and index == 0: + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + if not md5_compare and not summary_compare and result_item[1] == CompareConst.N_A: + result_item.extend(["-1"]) + result.append(result_item) + + +def merge_tensor(tensor_list, summary_compare, md5_compare): + op_dict = {} + op_dict["op_name"] = [] + op_dict["input_struct"] = [] + op_dict["kwargs_struct"] = [] + op_dict["output_struct"] = [] + op_dict["summary"] = [] + op_dict["stack_info"] = [] + + all_mode_bool = not (summary_compare or md5_compare) + if all_mode_bool: + op_dict["data_name"] = [] + + for tensor in tensor_list: + if len(tensor) == 2: + op_dict['stack_info'].append(tensor['full_info']) + break + op_dict["op_name"].append(tensor['full_op_name']) + if not md5_compare: + if tensor['full_op_name'].find("input") != -1: + op_dict["input_struct"].append((tensor['dtype'], tensor['shape'])) + elif tensor['full_op_name'].find("kwarg") != -1: + op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'])) + elif tensor['full_op_name'].find("output") != -1: + op_dict["output_struct"].append((tensor['dtype'], tensor['shape'])) + else: + if tensor['full_op_name'].find("input") != -1: + op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5'])) + elif tensor['full_op_name'].find("kwarg") != -1: + op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5'])) + elif tensor['full_op_name'].find("output") != -1: + op_dict["output_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5'])) + + op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']]) + + if all_mode_bool: + op_dict["data_name"].append(tensor['data_name']) + + if not op_dict["kwargs_struct"]: + del op_dict["kwargs_struct"] + return op_dict if op_dict["op_name"] else {} + + +def _compare_parser(parser): + parser.add_argument("-i", "--input_path", dest="input_path", type=str, + help=" The compare input path, a dict json.", required=True) + parser.add_argument("-o", "--output_path", dest="output_path", type=str, + help=" The compare task result out path.", required=True) + parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true", + help=" Whether to save stack info.", required=False) + parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true", + help=" Whether to give advisor.", required=False) + parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true", + help=" Whether to perform a fuzzy match on the api name.", required=False) + parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True, + help=" The cell mapping file path.", required=False) + parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True, + help=" The api mapping file path.", required=False) + + + + + diff --git a/debug/accuracy_tools/atat/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py similarity index 85% rename from debug/accuracy_tools/atat/core/data_dump/data_collector.py rename to debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index f6a9a70b138f1f58c777c5eacf1168b628de3f38..db437539afeb98050ce59aad87a1e79d98b84085 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -1,11 +1,10 @@ - import os -from atat.core.data_dump.scope import build_scope, ListScope -from atat.core.data_dump.json_writer import DataWriter -from atat.core.common.log import logger -from atat.core.common.const import Const -from atat.core.data_dump.data_processor.factory import DataProcessorFactory +from msprobe.core.data_dump.scope import build_scope, ListScope +from msprobe.core.data_dump.json_writer import DataWriter +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const +from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory def build_data_collector(config): @@ -21,7 +20,8 @@ class DataCollector: self.config = config self.data_writer = DataWriter() self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer) - self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) if self.config.framework == Const.PT_FRAMEWORK else None + self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) \ + if self.config.framework == Const.PT_FRAMEWORK else None self.module_count = {} if self.config.task == Const.FREE_BENCHMARK: self.scope = build_scope(ListScope, self.config.scope, self.config.list) @@ -35,7 +35,7 @@ class DataCollector: @property def dump_file_path(self): return self.data_writer.dump_file_path - + @staticmethod def check_scope_and_pid(scope, name, pid): return (not scope or scope.check(name)) and pid == os.getpid() @@ -43,10 +43,10 @@ class DataCollector: @staticmethod def is_inplace(module): return getattr(module, "op_is_inplace", False) - + def if_return_forward_new_output(self): return self.data_processor.if_return_forward_new_output() - + def get_forward_new_output(self): return self.data_processor.get_forward_new_output() @@ -88,8 +88,11 @@ class DataCollector: else: data_info = self.data_processor.analyze_forward_inplace(name, module_input_output) if self.config.level == "L2": - return + return self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) + if self.data_processor.stop_run(): + self.handle_data(name, data_info, use_buffer=False) + raise Exception("[msprobe] exit") self.handle_data(name, data_info) def backward_data_collect(self, name, module, pid, module_input_output): @@ -98,6 +101,9 @@ class DataCollector: return data_info = self.data_processor.analyze_backward(name, module, module_input_output) + if self.data_processor.stop_run(): + self.handle_data(name, data_info, use_buffer=False) + raise Exception("[msprobe] exit") self.handle_data(name, data_info) def update_construct(self, name): @@ -105,12 +111,15 @@ class DataCollector: self.data_writer.update_construct({name: self.module_processor.api_parent_node}) self.data_writer.update_construct(self.module_processor.module_node) - def handle_data(self, name, data_info): + def handle_data(self, name, data_info, use_buffer=True): msg = f"msProbe is collecting data on {name}. " if data_info: msg = self.update_data(data_info, msg) logger.info(msg) - self.data_writer.flush_data_when_buffer_is_full() + if use_buffer: + self.data_writer.flush_data_when_buffer_is_full() + else: + self.write_json() def module_count_func(self, name, name_template): module_name = name.split(Const.SEP)[-3] @@ -135,6 +144,6 @@ class DataCollector: def update_dump_paths(self, *args): self.data_writer.update_dump_paths(*args) self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level) - + def update_iter(self, current_iter): self.data_processor.update_iter(current_iter) diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py similarity index 87% rename from debug/accuracy_tools/atat/core/data_dump/data_processor/base.py rename to debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 208c053192c6da674ec9c7f522a0affdf1d091e2..2fbc86b5656c3bcfe14b2fe9fe6bb295451e9466 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -3,9 +3,9 @@ import inspect from dataclasses import dataclass from typing import Tuple, Dict, Optional, Any import numpy as np -from atat.core.common.log import logger -from atat.core.common.utils import convert_tuple -from atat.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.core.common.utils import convert_tuple +from msprobe.core.common.const import Const @dataclass @@ -35,11 +35,11 @@ class ModuleBackwardInputsOutputs: @property def grad_input_tuple(self): return convert_tuple(self.grad_input) - + @property def grad_output_tuple(self): - return convert_tuple(self.grad_output) - + return convert_tuple(self.grad_output) + class TensorStatInfo: def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None): @@ -53,7 +53,7 @@ class BaseDataProcessor: _recursive_key_stack = [] special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, bool, int, float, str, slice) - + def __init__(self, config, data_writer): self.data_writer = data_writer self.config = config @@ -65,11 +65,11 @@ class BaseDataProcessor: self.current_iter = 0 self._return_forward_new_output = False self._forward_new_output = None - + @property def data_path(self): return self.data_writer.dump_tensor_data_dir - + @staticmethod def analyze_api_call_stack(name): stack_str = [] @@ -87,7 +87,7 @@ class BaseDataProcessor: stack_str.append(stack_line) stack_info_struct = {name: stack_str} return stack_info_struct - + @staticmethod def _convert_numpy_to_builtin(arg): type_mapping = { @@ -103,26 +103,15 @@ class BaseDataProcessor: if isinstance(arg, numpy_type): return builtin_type(arg), type(arg).__name__ return arg, '' - + @staticmethod def _analyze_numpy(value, numpy_type): return {"type": numpy_type, "value": value} - - @staticmethod - def _analyze_builtin(arg): - single_arg = {} - if isinstance(arg, slice): - single_arg.update({"type": "slice"}) - single_arg.update({"value": [arg.start, arg.stop, arg.step]}) - else: - single_arg.update({"type": type(arg).__name__}) - single_arg.update({"value": arg}) - return single_arg - + @classmethod def get_special_types(cls): return cls.special_type - + @classmethod def recursive_apply_transform(cls, args, transform): if isinstance(args, cls.get_special_types()): @@ -142,9 +131,11 @@ class BaseDataProcessor: resutl_dict[k] = cls.recursive_apply_transform(arg, transform) cls._recursive_key_stack.pop() return resutl_dict - else: + elif args is not None: logger.warning(f"Data type {type(args)} is not supported.") return None + else: + return None def if_return_forward_new_output(self): return self._return_forward_new_output @@ -175,13 +166,14 @@ class BaseDataProcessor: return (Const.ALL in self.config.data_mode or forward_backward in self.config.data_mode or input_output in self.config.data_mode) - - def analyze_pre_forward(self, name, module,module_input_output: ModuleForwardInputsOutputs): + + def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): pass - + def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): api_info_struct = {} - if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): # check whether data_mode contains forward or input + # check whether data_mode contains forward or input + if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): api_info_struct[name] = {} self.api_data_category = Const.INPUT args_info_list = self.analyze_element(module_input_output.args_tuple) @@ -190,13 +182,14 @@ class BaseDataProcessor: kwargs_info_list = self.analyze_element(module_input_output.kwargs) api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list - if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): # check whether data_mode contains forward or output + # check whether data_mode contains forward or output + if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): api_info_struct[name] = api_info_struct.get(name, {}) self.api_data_category = Const.OUTPUT output_info_list = self.analyze_element(module_input_output.output_tuple) api_info_struct[name][Const.OUTPUT] = output_info_list return api_info_struct - + def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs): api_info_struct = {} if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): @@ -208,7 +201,7 @@ class BaseDataProcessor: kwargs_info_list = self.analyze_element(module_input_output.kwargs) api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list return api_info_struct - + def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs): concat_args = module_input_output.concat_args_and_kwargs() api_info_struct = {} @@ -218,26 +211,29 @@ class BaseDataProcessor: output_info_list = self.analyze_element(concat_args) api_info_struct[name][Const.OUTPUT] = output_info_list return api_info_struct - + def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs): api_info_struct = {} - if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT): + if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT): api_info_struct[name] = {} - self.api_data_category = Const.OUTPUT + self.api_data_category = Const.INPUT input_info_list = self.analyze_element(module_input_output.grad_input_tuple) - api_info_struct[name][Const.GRAD_INPUT] = input_info_list + api_info_struct[name][Const.INPUT] = input_info_list - if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT): + if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT): api_info_struct[name] = api_info_struct.get(name, {}) - self.api_data_category = Const.INPUT + self.api_data_category = Const.OUTPUT output_info_list = self.analyze_element(module_input_output.grad_output_tuple) - api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list + api_info_struct[name][Const.OUTPUT] = output_info_list return api_info_struct def get_save_file_path(self, suffix): - file_format = "pt" if self.config.framework == Const.PT_FRAMEWORK else "npy" + file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + - suffix + Const.SEP + file_format) + suffix + file_format) file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) - return dump_data_name, file_path \ No newline at end of file + return dump_data_name, file_path + + def stop_run(self): + return False diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py similarity index 89% rename from debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py rename to debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py index bcc771f3684aa55a422691bd9dbfff31f07773dd..ad74acdeebaf9df257eca35cc670a6131805758c 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py @@ -1,10 +1,10 @@ -from atat.core.common.const import Const +from msprobe.core.common.const import Const class DataProcessorFactory: _data_processor = {} _module_processor = {} - + @classmethod def register_processor(cls, framework, task, processor_class): key = (framework, task) @@ -13,7 +13,7 @@ class DataProcessorFactory: @classmethod def register_module_processor(cls, framework, processor_class): cls._module_processor[framework] = processor_class - + @classmethod def get_module_processor(cls, framework): processor_class = cls._module_processor.get(framework) @@ -39,7 +39,7 @@ class DataProcessorFactory: TensorDataProcessor as PytorchTensorDataProcessor, OverflowCheckDataProcessor as PytorchOverflowCheckDataProcessor, FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor, - KernelDumpDataProcessor as PytorchKernelDumpDataProcessor + KernelDumpDataProcessor as PytorchKernelDumpDataProcessor ) from ....pytorch.module_processer import ModuleProcesser cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor) @@ -47,15 +47,13 @@ class DataProcessorFactory: cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor) cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor) cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor) - cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser) + cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser) elif framework == Const.MS_FRAMEWORK: from .mindspore_processor import ( StatisticsDataProcessor as MindsporeStatisticsDataProcessor, TensorDataProcessor as MindsporeTensorDataProcessor, - OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor, - FreeBenchmarkDataProcessor as MindsporeFreeBenchmarkDataProcessor + OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor ) cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor) cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor) cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor) - cls.register_processor(Const.MS_FRAMEWORK, Const.FREE_BENCHMARK, MindsporeFreeBenchmarkDataProcessor) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..0cfb27cda2afa8794943dceb23b4085f0305c50b --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -0,0 +1,206 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import zlib +import mindspore as ms +from mindspore import ops +import numpy as np + +from msprobe.core.common.const import Const +from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo, + ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs) +from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode, FileCheckConst +from msprobe.mindspore.dump.hook_cell.wrap_functional import load_ops_functions +from msprobe.mindspore.common.utils import convert_bf16_to_fp32 +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.dump.hook_cell.api_registry import api_register + + +class MindsporeDataProcessor(BaseDataProcessor): + mindspore_special_type = tuple([ms.Tensor]) + ops_func, mint_ops_func, _ = load_ops_functions() + + def __init__(self, config, data_writer): + super().__init__(config, data_writer) + self.mindspore_object_key = { + "dtype": self.analyze_dtype_in_kwargs + } + + @staticmethod + def get_md5_for_tensor(x): + x = convert_bf16_to_fp32(x) + tensor_bytes = x.asnumpy().tobytes() + crc32_hash = zlib.crc32(tensor_bytes) + return f"{crc32_hash:08x}" + + @staticmethod + def analyze_dtype_in_kwargs(element): + return {"type": "mindspore.dtype", "value": str(element)} + + @staticmethod + def _analyze_builtin(arg): + single_arg = {} + if isinstance(arg, slice): + single_arg.update({"type": "slice"}) + # slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型 + values = [ + value if not isinstance(value, ms.Tensor) else value.item() + for value in [arg.start, arg.stop, arg.step] + ] + single_arg.update({"value": values}) + else: + single_arg.update({"type": type(arg).__name__}) + single_arg.update({"value": arg}) + return single_arg + + @classmethod + def get_special_types(cls): + return super().get_special_types() + cls.mindspore_special_type + + def get_stat_info(self, data): + tensor_stat = TensorStatInfo() + if data.numel() == 0: + return tensor_stat + elif data.dtype == ms.bool_: + tensor_stat.max = self.mint_ops_func["max"](data).item() + tensor_stat.min = self.mint_ops_func["min"](data).item() + elif not data.shape: + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() + elif data.dtype == ms.complex64 or data.dtype == ms.complex128: + data_abs = np.abs(data.asnumpy()) + tensor_stat.max = np.max(data_abs) + tensor_stat.min = np.min(data_abs) + tensor_stat.mean = np.mean(data_abs) + tensor_stat.norm = np.linalg.norm(data_abs) + else: + if data.dtype == ms.bfloat16 or not ops.is_floating_point(data): + data = data.to(ms.float32) + api_register.norm_inner_op_set_ori_func() + tensor_stat.max = self.mint_ops_func["max"](data).item() + tensor_stat.min = self.mint_ops_func["min"](data).item() + tensor_stat.mean = self.mint_ops_func["mean"](data).item() + tensor_stat.norm = self.ops_func["norm"](data).item() + api_register.norm_inner_op_set_hook_func() + return tensor_stat + + def analyze_single_element(self, element, suffix_stack): + if suffix_stack and suffix_stack[-1] in self.mindspore_object_key: + return self.mindspore_object_key[suffix_stack[-1]](element) + + converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) + if converted_numpy is not element: + return self._analyze_numpy(converted_numpy, numpy_type) + if isinstance(element, ms.Tensor): + return self._analyze_tensor(element, Const.SEP.join(suffix_stack)) + + if isinstance(element, (bool, int, float, str, slice)): + return self._analyze_builtin(element) + return {} + + def analyze_element(self, element): + return self.recursive_apply_transform(element, self.analyze_single_element) + + def _analyze_tensor(self, tensor, suffix): + tensor_stat = self.get_stat_info(tensor) + tensor_json = { + 'type': 'mindspore.Tensor', + 'dtype': str(tensor.dtype), + 'shape': tensor.shape, + 'Max': tensor_stat.max, + 'Min': tensor_stat.min, + 'Mean': tensor_stat.mean, + 'Norm': tensor_stat.norm + } + if self.config.summary_mode == Const.MD5: + tensor_md5 = self.get_md5_for_tensor(tensor) + tensor_json.update({Const.MD5: tensor_md5}) + return tensor_json + + +class StatisticsDataProcessor(MindsporeDataProcessor): + pass + + +class TensorDataProcessor(MindsporeDataProcessor): + def _analyze_tensor(self, tensor, suffix): + dump_data_name, file_path = self.get_save_file_path(suffix) + single_arg = super()._analyze_tensor(tensor, suffix) + single_arg.update({"data_name": dump_data_name}) + if not path_len_exceeds_limit(file_path): + tensor = convert_bf16_to_fp32(tensor) + saved_tensor = tensor.asnumpy() + np.save(file_path, saved_tensor) + change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) + else: + logger.warning(f'The file path {file_path} length exceeds limit.') + return single_arg + + +class OverflowCheckDataProcessor(MindsporeDataProcessor): + __slots__ = ["cached_tensors_and_file_paths"] + + def __init__(self, config, data_writer): + super().__init__(config, data_writer) + self.cached_tensors_and_file_paths = {} + self.real_overflow_dump_times = 0 + self.overflow_nums = config.overflow_nums + + def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): + self.has_overflow = False + api_info_struct = super().analyze_forward(name, module, module_input_output) + self.maybe_save_overflow_data() + return api_info_struct if self.has_overflow else None + + def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs): + self.has_overflow = False + api_info_struct = super().analyze_backward(name, module, module_input_output) + self.maybe_save_overflow_data() + return api_info_struct if self.has_overflow else None + + def maybe_save_overflow_data(self): + if self.has_overflow: + for file_path, tensor in self.cached_tensors_and_file_paths.items(): + tensor = convert_bf16_to_fp32(tensor) + np.save(file_path, tensor.asnumpy()) + change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) + self.real_overflow_dump_times += 1 + self.cached_tensors_and_file_paths = {} + + def stop_run(self): + if self.overflow_nums == -1: + return False + if self.real_overflow_dump_times >= self.overflow_nums: + logger.warning(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_dump_times}") + return True + return False + + def _analyze_maybe_overflow_tensor(self, tensor_json): + if tensor_json['Max'] is None: + return + if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']): + self.has_overflow = True + if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']): + self.has_overflow = True + + def _analyze_tensor(self, tensor, suffix): + dump_data_name, file_path = self.get_save_file_path(suffix) + if not path_len_exceeds_limit(file_path): + self.cached_tensors_and_file_paths.update({file_path: tensor}) + else: + logger.warning(f'The file path {file_path} length exceeds limit.') + single_arg = super()._analyze_tensor(tensor, suffix) + self._analyze_maybe_overflow_tensor(single_arg) + single_arg.update({"data_name": dump_data_name}) + return single_arg diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py similarity index 87% rename from debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py rename to debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index cf3c5ebe5864ff48d9f2442c020cb3ae99473b50..0986ccf826330f9783a5dde2cc1d07a87c2e57f4 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -5,18 +5,19 @@ from typing import List import numpy as np import torch -from atat.core.common.exceptions import MsaccException -from atat.core.common.file_check import path_len_exceeds_limit, change_mode -from atat.core.common.log import logger -from atat.core.common.const import Const, OverflowConst, FileCheckConst -from atat.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const, OverflowConst, FileCheckConst +from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ ModuleForwardInputsOutputs, TensorStatInfo -from atat.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow +from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow try: import torch_npu + is_gpu = False except ImportError: - pass + is_gpu = True class PytorchDataProcessor(BaseDataProcessor): @@ -76,6 +77,42 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item() tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item() return tensor_stat + + @staticmethod + def handle_tensor_extremum_nan_inf(tensor, operator): + data_clone = tensor.detach() + data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) + if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): + return float('nan') + finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) + if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: + finite_values = data_clone[finite_mask] + return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(finite_values).item() + else: + data_no_nan = data_clone[~data_nan] + return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(data_no_nan).item() + + @staticmethod + def _analyze_builtin(arg): + single_arg = {} + if isinstance(arg, slice): + single_arg.update({"type": "slice"}) + # slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型 + values = [ + value if not isinstance(value, torch.Tensor) else value.item() + for value in [arg.start, arg.stop, arg.step] + ] + single_arg.update({"value": values}) + else: + single_arg.update({"type": type(arg).__name__}) + single_arg.update({"value": arg}) + return single_arg + + @staticmethod + def _analyze_torch_size(arg): + return {"type": "torch.Size", "value": list(arg)} @classmethod def get_special_types(cls): @@ -93,14 +130,11 @@ class PytorchDataProcessor(BaseDataProcessor): return self._analyze_tensor(element, Const.SEP.join(suffix_stack)) if isinstance(element, (bool, int, float, str, slice)): return self._analyze_builtin(element) - return None + return {} def analyze_element(self, element): return self.recursive_apply_transform(element, self.analyze_single_element) - def _analyze_torch_size(arg): - return {"type": "torch.Size", "value": list(arg)} - def _analyze_tensor(self, tensor, suffix): tensor_stat = self.get_stat_info(tensor) tensor_json = {} @@ -112,9 +146,17 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_json.update({"Mean": tensor_stat.mean}) tensor_json.update({"Norm": tensor_stat.norm}) tensor_json.update({"requires_grad": tensor.requires_grad}) - if self.config.summary_mode == "md5": + + if tensor_stat.max is not None: + if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max): + tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max") + if tensor_stat.min is not None: + if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min): + tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min") + + if self.config.summary_mode == Const.MD5: tensor_md5 = self.get_md5_for_tensor(tensor) - tensor_json.update({"md5": tensor_md5}) + tensor_json.update({Const.MD5: tensor_md5}) return tensor_json @@ -142,7 +184,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): super().__init__(config, data_writer) self.cached_tensors_and_file_paths = {} self.real_overflow_dump_times = 0 - self.overflow_nums = config.overflow_num + self.overflow_nums = config.overflow_nums self.bits_for_overflow = 8 @staticmethod @@ -150,21 +192,6 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE) return overflow_mode == Const.ENV_ENABLE - @staticmethod - def handle_tensor_extremum_nan_inf(data_clone, operator): - data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) - if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): - return float('nan') - finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) - if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: - finite_values = data_clone[finite_mask] - return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ - torch._C._VariableFunctionsClass.min(finite_values).item() - else: - data_no_nan = data_clone[~data_nan] - return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ - torch._C._VariableFunctionsClass.min(data_no_nan).item() - def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): self.has_overflow = False api_info_struct = super().analyze_forward(name, module, module_input_output) @@ -190,7 +217,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): if self.overflow_nums == -1: return if self.real_overflow_dump_times >= self.overflow_nums: - raise MsaccException(MsaccException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times)) + raise MsprobeException(MsprobeException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times)) def check_overflow_npu(self): if self.overflow_debug_mode_enalbe(): @@ -210,16 +237,13 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): else: torch_npu._C._clear_overflow_npu() - def _analyze_maybe_overflow_tensor(self, tensor_json, tensor): - data_clone = tensor.detach() - if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan(): + def _analyze_maybe_overflow_tensor(self, tensor_json): + if is_gpu or (hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan()): if tensor_json['Max'] is None: return if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']): - tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max") self.has_overflow = True if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']): - tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min") self.has_overflow = True else: self.has_overflow = self.check_overflow_npu() @@ -233,7 +257,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): else: logger.warning(f'The file path {file_path} length exceeds limit.') single_arg = super()._analyze_tensor(tensor, suffix) - self._analyze_maybe_overflow_tensor(single_arg, tensor) + self._analyze_maybe_overflow_tensor(single_arg) single_arg.update({"data_name": dump_data_name}) return single_arg diff --git a/debug/accuracy_tools/atat/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py similarity index 85% rename from debug/accuracy_tools/atat/core/data_dump/json_writer.py rename to debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index 23f37b2342e9bde9d69bb65022a40911d5a54dc4..99cc5f3159ee038133e874109a78b2a7a85d9413 100644 --- a/debug/accuracy_tools/atat/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -4,9 +4,9 @@ import fcntl import json from pathlib import Path -from atat.core.common.file_check import change_mode -from atat.core.common.log import logger -from atat.core.common.const import Const, FileCheckConst +from msprobe.core.common.file_utils import change_mode, FileOpen +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const, FileCheckConst class DataWriter: @@ -30,20 +30,20 @@ class DataWriter: return is_exists = os.path.exists(file_path) append = "a+" if is_exists else "w+" - with os.fdopen( - os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline="" - ) as csv_file: + with FileOpen(file_path, append) as csv_file: spawn_writer = csv.writer(csv_file) if not is_exists: spawn_writer.writerow(result_header) spawn_writer.writerows([result,]) + is_new_file = not is_exists + if is_new_file: + change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) def initialize_json_file(self, **kwargs): kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) - with os.fdopen( - os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w' - ) as f: + with FileOpen(self.dump_file_path, 'w') as f: json.dump(kwargs, f) + change_mode(self.dump_file_path, FileCheckConst.DATA_FILE_AUTHORITY) if os.path.exists(self.stack_file_path): os.remove(self.stack_file_path) @@ -83,7 +83,7 @@ class DataWriter: def write_data_json(self, file_path): logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ") if Path(file_path).exists() and os.path.getsize(file_path) > 0: - with open(file_path, "r+") as f: + with FileOpen(file_path, "r+") as f: fcntl.flock(f, fcntl.LOCK_EX) data_to_write = json.load(f) fcntl.flock(f, fcntl.LOCK_UN) @@ -91,7 +91,7 @@ class DataWriter: self.init_json['data_path'] = self.dump_tensor_data_dir data_to_write = self.init_json data_to_write[Const.DATA].update(self.cache_data[Const.DATA]) - with open(file_path, 'w+') as f: + with FileOpen(file_path, 'w+') as f: fcntl.flock(f, fcntl.LOCK_EX) json.dump(data_to_write, f, indent=1) fcntl.flock(f, fcntl.LOCK_UN) @@ -99,13 +99,13 @@ class DataWriter: self.cache_data[Const.DATA].clear() def write_stack_info_json(self, file_path): - with open(file_path, 'w+') as f: + with FileOpen(file_path, 'w+') as f: fcntl.flock(f, fcntl.LOCK_EX) json.dump(self.cache_stack, f, indent=1) fcntl.flock(f, fcntl.LOCK_UN) def write_construct_info_json(self, file_path): - with open(file_path, 'w+') as f: + with FileOpen(file_path, 'w+') as f: fcntl.flock(f, fcntl.LOCK_EX) json.dump(self.cache_construct, f, indent=1) fcntl.flock(f, fcntl.LOCK_UN) diff --git a/debug/accuracy_tools/atat/core/data_dump/scope.py b/debug/accuracy_tools/msprobe/core/data_dump/scope.py similarity index 98% rename from debug/accuracy_tools/atat/core/data_dump/scope.py rename to debug/accuracy_tools/msprobe/core/data_dump/scope.py index e7114f343fe724ffdd40d4837a09b8417d03a1b0..1d74c3e461ac4b0005e4fdd40ae1fe2c12bb1c4e 100644 --- a/debug/accuracy_tools/atat/core/data_dump/scope.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/scope.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from atat.core.common.exceptions import ScopeException -from atat.core.common.const import Const +from msprobe.core.common.exceptions import ScopeException +from msprobe.core.common.const import Const def build_scope(scope_class, scope=None, api_list=None): diff --git a/debug/accuracy_tools/msprobe/mindspore/__init__.py b/debug/accuracy_tools/msprobe/mindspore/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf42d1e39feef616e6eb2cc296a099b0bddfd98 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/__init__.py @@ -0,0 +1 @@ +from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger diff --git a/debug/accuracy_tools/msprobe/mindspore/common/log.py b/debug/accuracy_tools/msprobe/mindspore/common/log.py new file mode 100644 index 0000000000000000000000000000000000000000..ec027c750133ce4aabdac4ed914b4a5c50b2a2f1 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/common/log.py @@ -0,0 +1,38 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import time +import sys + +from msprobe.mindspore.common.utils import get_rank_if_initialized +from msprobe.core.common.log import BaseLogger +from msprobe.core.common.exceptions import DistributedNotInitializedError + + +class MindsporeLogger(BaseLogger): + def __init__(self): + super().__init__() + + def get_rank(self): + try: + current_rank = get_rank_if_initialized() + except DistributedNotInitializedError: + current_rank = None + + return current_rank + + +logger = MindsporeLogger() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6abf0a1ee8870c6044d7cad95542cb89ebf81d46 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -0,0 +1,44 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore as ms +from msprobe.core.common.exceptions import DistributedNotInitializedError + + +def get_rank_if_initialized(): + if ms.communication.GlobalComm.INITED: + return ms.communication.get_rank() + else: + raise DistributedNotInitializedError("mindspore distributed environment is not initialized") + + +def convert_bf16_to_fp32(tensor): + if tensor.dtype == ms.bfloat16: + tensor = tensor.to(ms.float32) + return tensor + + +class MsprobeStep(ms.train.Callback): + + def __init__(self, debugger): + super(MsprobeStep, self).__init__() + self.debugger = debugger + + def on_train_step_begin(self, run_context): + self.debugger.start() + + def on_train_step_end(self, run_context): + self.debugger.stop() + self.debugger.step() diff --git a/debug/accuracy_tools/atat/mindspore/overflow_check/__init__.py b/debug/accuracy_tools/msprobe/mindspore/debugger/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/mindspore/overflow_check/__init__.py rename to debug/accuracy_tools/msprobe/mindspore/debugger/__init__.py diff --git a/debug/accuracy_tools/atat/mindspore/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py similarity index 69% rename from debug/accuracy_tools/atat/mindspore/debugger/debugger_config.py rename to debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py index 56a4b9bf758197d77ef04874f2865e2136d6f67c..23cb7294b8dc0d64e012f5fac2b863bdfe871bbe 100644 --- a/debug/accuracy_tools/atat/mindspore/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py @@ -1,13 +1,10 @@ import os +from msprobe.core.common.utils import Const +from msprobe.core.common.const import MsConst -class DebuggerConfig: - convert_map = { - "L0": "cell", - "L1": "api", - "L2": 'kernel' - } +class DebuggerConfig: def __init__(self, common_config, task_config): self.dump_path = common_config.dump_path self.task = common_config.task @@ -15,18 +12,22 @@ class DebuggerConfig: self.step = [] if not common_config.step else common_config.step if not common_config.level: common_config.level = "L1" - self.level = DebuggerConfig.convert_map[common_config.level] + self.level = MsConst.TOOL_LEVEL_DICT.get(common_config.level, MsConst.API) + self.level_ori = common_config.level self.list = [] if not task_config.list else task_config.list - self.data_mode = [] if not task_config.data_mode else task_config.data_mode + self.scope = [] if not task_config.scope else task_config.scope + self.data_mode = [] if not task_config.data_mode else task_config.data_mode self.file_format = task_config.file_format + self.overflow_nums = 1 if not task_config.overflow_nums else task_config.overflow_nums self.check_mode = task_config.check_mode - + self.framework = Const.MS_FRAMEWORK + self.summary_mode = task_config.summary_mode self.check() def check(self): if not self.dump_path: raise Exception("Dump path is empty.") - if not os.path.isabs(self.dump_path): + if self.level_ori != "L1" and not os.path.isabs(self.dump_path): raise Exception("Dump path must be absolute path.") if not self.task: self.task = "statistics" diff --git a/debug/accuracy_tools/atat/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py similarity index 31% rename from debug/accuracy_tools/atat/mindspore/debugger/precision_debugger.py rename to debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 0099074762f0746c1bd8341047f37b3e5fe08855..5475dc3586c35687fec63b51f265ac83c0d33a87 100644 --- a/debug/accuracy_tools/atat/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -1,7 +1,12 @@ import os -from atat.mindspore.ms_config import parse_json_config -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.task_handler_factory import TaskHandlerFactory + +import mindspore as ms + +from msprobe.mindspore.service import Service +from msprobe.mindspore.ms_config import parse_json_config +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.task_handler_factory import TaskHandlerFactory +from msprobe.core.common.const import MsConst class PrecisionDebugger: @@ -12,6 +17,8 @@ class PrecisionDebugger: cls._instance = super().__new__(cls) cls._instance.initialized = False cls._instance.config = None + cls.service = None + cls.first_start = False return cls._instance def __init__(self, config_path=None): @@ -23,10 +30,46 @@ class PrecisionDebugger: self.config = DebuggerConfig(common_config, task_config) self.initialized = True + @staticmethod + def _get_execution_mode(): + if ms.get_context("mode") == ms.GRAPH_MODE: + if ms.context.get_jit_config().get("jit_level") == "O2" or ms.get_context("jit_level") == "O2": + return MsConst.GRAPH_GE_MODE + else: + return MsConst.GRAPH_KBYK_MODE + else: + return MsConst.PYNATIVE_MODE + @classmethod - def start(cls, target=None): + def start(cls): instance = cls._instance if not instance: raise Exception("No instance of PrecisionDebugger found.") - handler = TaskHandlerFactory.create(instance.config) - handler.handle() + + instance.config.execution_mode = instance._get_execution_mode() + if instance.config.execution_mode == MsConst.PYNATIVE_MODE and instance.config.level == MsConst.API: + if not instance.service: + instance.service = Service(instance.config) + instance.service.start() + else: + if not instance.first_start: + handler = TaskHandlerFactory.create(instance.config) + handler.handle() + + instance.first_start = True + + @classmethod + def stop(cls): + instance = cls._instance + if not instance: + raise Exception("PrecisionDebugger instance is not created.") + if instance.service: + instance.service.stop() + + @classmethod + def step(cls): + instance = cls._instance + if not instance: + raise Exception("PrecisionDebugger instance is not created.") + if instance.service: + instance.service.step() diff --git a/debug/accuracy_tools/atat/mindspore/doc/dump.md b/debug/accuracy_tools/msprobe/mindspore/doc/dump.md similarity index 81% rename from debug/accuracy_tools/atat/mindspore/doc/dump.md rename to debug/accuracy_tools/msprobe/mindspore/doc/dump.md index 3321a4da12bcf7ff5f215c3ddfe8fe922159f750..425d0683a268ebdcaf54a4f70b5e448bb1233f3c 100644 --- a/debug/accuracy_tools/atat/mindspore/doc/dump.md +++ b/debug/accuracy_tools/msprobe/mindspore/doc/dump.md @@ -1,8 +1,8 @@ # **精度数据采集** -atat工具主要通过在训练脚本内添加dump接口并启动训练的方式来采集精度数据。 +msprobe工具主要通过在训练脚本内添加dump接口并启动训练的方式来采集精度数据。 -执行dump操作需要安装atat工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 +执行dump操作需要安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 ## dump接口介绍 @@ -12,7 +12,7 @@ atat工具主要通过在训练脚本内添加dump接口并启动训练的方式 通过加载dump配置文件的方式来确定dump操作的详细配置。 -可以在from atat.mindspore import PrecisionDebugger和模型初始化之间的任意位置添加该接口。 +可以在from msprobe.mindspore import PrecisionDebugger和模型初始化之间的任意位置添加该接口。 **原型** @@ -43,7 +43,7 @@ debugger.start() ## 示例代码 ```Python -from atat.mindspore import PrecisionDebugger +from msprobe.mindspore import PrecisionDebugger debugger = PrecisionDebugger(config_path="./config.json") # 请勿将以上初始化流程插入到循环代码中 # 下面代码也可以用PrecisionDebugger.start() diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/__init__.py b/debug/accuracy_tools/msprobe/mindspore/dump/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/__init__.py rename to debug/accuracy_tools/msprobe/mindspore/dump/__init__.py diff --git a/debug/accuracy_tools/atat/mindspore/dump/api_kbk_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/api_kbk_dump.py similarity index 91% rename from debug/accuracy_tools/atat/mindspore/dump/api_kbk_dump.py rename to debug/accuracy_tools/msprobe/mindspore/dump/api_kbk_dump.py index a53841189f5a52de74900b9ba4382e0746dfee3a..5c7af45d79060c00ce198f19a589d46bacf1f756 100644 --- a/debug/accuracy_tools/atat/mindspore/dump/api_kbk_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/api_kbk_dump.py @@ -1,9 +1,9 @@ import os import json -from atat.core.common.utils import make_dump_path_if_not_exists -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.core.common.log import logger -from atat.core.common.file_check import FileOpen +from msprobe.core.common.utils import make_dump_path_if_not_exists +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.core.common.log import logger +from msprobe.core.common.file_check import FileOpen class ApiKbkDump: diff --git a/debug/accuracy_tools/atat/mindspore/dump/dump_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py similarity index 82% rename from debug/accuracy_tools/atat/mindspore/dump/dump_tool_factory.py rename to debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py index ab534edc243dfd5f44688358fe4ca8edb6a8a12d..2c4579b0e75fe1573f387f696c3d9e4efd4945e3 100644 --- a/debug/accuracy_tools/atat/mindspore/dump/dump_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py @@ -1,6 +1,6 @@ -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.dump.api_kbk_dump import ApiKbkDump -from atat.mindspore.dump.kernel_graph_dump import KernelGraphDump +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.dump.api_kbk_dump import ApiKbkDump +from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump class DumpToolFactory: diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..5508416fde0729d5a8cf31333332464783024c07 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py @@ -0,0 +1,104 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import mindspore as ms +from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \ + HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP +from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor +from msprobe.core.common.utils import Const + + +class ApiRegistry: + def __init__(self): + self.tensor_ori_attr = {} + self.functional_ori_attr = {} + self.mint_ops_ori_attr = {} + self.mint_func_ops_ori_attr = {} + self.norm_inner_ops_ori_attr = {} + + self.tensor_hook_attr = {} + self.functional_hook_attr = {} + self.mint_ops_hook_attr = {} + self.mint_func_ops_hook_attr = {} + self.norm_inner_ops_hook_attr = {} + + self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"] + + @staticmethod + def store_ori_attr(ori_api_group, api_list, api_ori_attr): + for api in api_list: + if Const.SEP in api: + sub_module_name, sub_op = api.rsplit(Const.SEP, 1) + sub_module = getattr(ori_api_group, sub_module_name) + api_ori_attr[api] = getattr(sub_module, sub_op) + else: + api_ori_attr[api] = getattr(ori_api_group, api) + + @staticmethod + def set_api_attr(api_group, attr_dict): + for api, api_attr in attr_dict.items(): + if Const.SEP in api: + sub_module_name, sub_op = api.rsplit(Const.SEP, 1) + sub_module = getattr(api_group, sub_module_name, None) + if sub_module is not None: + setattr(sub_module, sub_op, api_attr) + else: + setattr(api_group, api, api_attr) + + def norm_inner_op_set_hook_func(self): + self.set_api_attr(ms.ops, self.norm_inner_ops_hook_attr) + + def norm_inner_op_set_ori_func(self): + self.set_api_attr(ms.ops, self.norm_inner_ops_ori_attr) + + def api_set_hook_func(self): + self.set_api_attr(ms.Tensor, self.tensor_hook_attr) + self.set_api_attr(ms.ops, self.functional_hook_attr) + self.set_api_attr(ms.mint, self.mint_ops_hook_attr) + self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_hook_attr) + + def api_set_ori_func(self): + self.set_api_attr(ms.Tensor, self.tensor_ori_attr) + self.set_api_attr(ms.ops, self.functional_ori_attr) + self.set_api_attr(ms.mint, self.mint_ops_ori_attr) + self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_ori_attr) + + def initialize_hook(self, hook): + self.store_ori_attr(ms.Tensor, get_tensor_ops(), self.tensor_ori_attr) + wrap_tensor_ops_and_bind(hook) + for attr_name in dir(HOOKTensor): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): + self.tensor_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKTensor, attr_name) + + functional_ops, mint_ops, mint_func_ops = get_functional_ops() + self.store_ori_attr(ms.ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr) + self.store_ori_attr(ms.ops, functional_ops, self.functional_ori_attr) + self.store_ori_attr(ms.mint, mint_ops, self.mint_ops_ori_attr) + self.store_ori_attr(ms.mint.nn.functional, mint_func_ops, self.mint_func_ops_ori_attr) + setup_hooks(hook) + for attr_name in dir(HOOKFunctionalOP): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): + self.functional_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKFunctionalOP, attr_name) + if attr_name[Const.ATTR_NAME_PREFIX_LEN:] in self.norm_inner_ops: + self.norm_inner_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKFunctionalOP, attr_name) + for attr_name in dir(HOOKMintOP): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): + self.mint_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintOP, attr_name) + for attr_name in dir(HOOKMintNNFunctionalOP): + if attr_name.startswith(Const.ATTR_NAME_PREFIX): + self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name) + + +api_register = ApiRegistry() diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py new file mode 100644 index 0000000000000000000000000000000000000000..57ed44111ca0bcdf4d8bafbf27a1373e55bb5480 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py @@ -0,0 +1,53 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from collections import defaultdict + +from mindspore import nn +from msprobe.core.common.const import Const + + +class HOOKCell(nn.Cell): + cell_count = defaultdict(int) + g_stop_hook = False + + def __init__(self, build_hook) -> None: + super(HOOKCell, self).__init__() + self.changed_status = False + self.input_kwargs = {} + self.prefix = "" + if not HOOKCell.g_stop_hook: + HOOKCell.g_stop_hook = True + self.changed_status = True + if hasattr(self, "prefix_op_name_"): + self.prefix = self.prefix_op_name_ + + HOOKCell.cell_count[self.prefix] += 1 + self.prefix = self.prefix + str(HOOKCell.cell_count[self.prefix] - 1) + Const.SEP + forward_hook, backward_hook = build_hook(self.prefix) + self.register_forward_hook(forward_hook) + self.register_backward_hook(backward_hook) + + # 重载call,加全局标志。 + def __call__(self, *args, **kwargs): + try: + self.input_kwargs = kwargs + out = super(HOOKCell, self).__call__(*args, **kwargs) + except Exception as e: + raise e + finally: + if self.changed_status: + self.changed_status = False + HOOKCell.g_stop_hook = False + return out diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml new file mode 100644 index 0000000000000000000000000000000000000000..089f444b6181f0623c8029926c4808ab22ae27ca --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml @@ -0,0 +1,925 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +# List of ops that register hooks + + +ops: + - adaptive_avg_pool1d + - adaptive_avg_pool2d + - adaptive_avg_pool3d + - adaptive_max_pool1d + - adaptive_max_pool2d + - avg_pool1d + - avg_pool2d + - avg_pool3d + - batch_norm + - bias_add + - ctc_greedy_decoder + - conv1d + - conv2d + - conv3d + - deformable_conv2d + - dense + - dropout + - dropout1d + - dropout2d + - dropout3d + - flatten + - fold + - fractional_max_pool3d + - lp_pool1d + - lp_pool2d + - lrn + - max_pool2d + - max_pool3d + - max_unpool1d + - max_unpool2d + - max_unpool3d + - unfold + - binary_cross_entropy + - binary_cross_entropy_with_logits + - cosine_embedding_loss + - cross_entropy + - ctc_loss + - gaussian_nll_loss + - hinge_embedding_loss + - huber_loss + - kl_div + - l1_loss + - margin_ranking_loss + - mse_loss + - multi_margin_loss + - multilabel_margin_loss + - multilabel_soft_margin_loss + - nll_loss + - smooth_l1_loss + - triplet_margin_loss + - elu + - fast_gelu + - gelu + - glu + - gumbel_softmax + - hardshrink + - hardsigmoid + - hardswish + - hardtanh + - leaky_relu + - log_softmax + - logsigmoid + - mish + - prelu + - relu + - relu6 + - celu + - rrelu + - selu + - sigmoid + - silu + - softmax + - softmin + - softshrink + - softsign + - tanh + - threshold + - cdist + - dist + - pdist + - choice_with_mask + - random_categorical + - log_uniform_candidate_sampler + - uniform_candidate_sampler + - affine_grid + - bounding_box_decode + - bounding_box_encode + - col2im + - check_valid + - crop_and_resize + - grid_sample + - interpolate + - iou + - pad + - padding + - pixel_shuffle + - pixel_unshuffle + - upsample + - abs + - absolute + - accumulate_n + - acos + - arccos + - acosh + - add + - addcdiv + - addcmul + - addmv + - addn + - angle + - arccosh + - arcsin + - arcsinh + - arctan + - arctanh + - arctan2 + - asin + - asinh + - atan + - atan2 + - atanh + - atleast_1d + - atleast_2d + - atleast_3d + - bessel_i0 + - bessel_i0e + - bessel_i1 + - bessel_i1e + - bessel_j0 + - bessel_j1 + - bessel_k0 + - bessel_k0e + - bessel_k1 + - bessel_k1e + - bessel_y0 + - bessel_y1 + - bitwise_and + - bitwise_left_shift + - bitwise_or + - bitwise_right_shift + - bitwise_xor + - ceil + - clamp + - clip + - combinations + - copysign + - cos + - cosh + - cosine_similarity + - cov + - diag_embed + - diff + - deg2rad + - digamma + - div + - divide + - erf + - erfc + - erfinv + - exp + - exp2 + - expm1 + - floor + - floor_div + - floor_mod + - float_power + - fmod + - frac + - gcd + - hypot + - igamma + - igammac + - imag + - i0 + - inv + - invert + - lcm + - ldexp + - lerp + - log + - log2 + - log10 + - log1p + - logaddexp + - logaddexp2 + - logical_and + - logical_not + - logical_or + - logical_xor + - logit + - mul + - multiply + - mvlgamma + - neg + - negative + - nextafter + - polar + - polygamma + - positive + - pow + - rad2deg + - ravel + - real + - reciprocal + - remainder + - rot90 + - round + - rsqrt + - sgn + - sign + - signbit + - sin + - sinc + - sinh + - sqrt + - square + - sub + - subtract + - t + - tan + - tanhshrink + - trapz + - tril_indices + - triu_indices + - true_divide + - trunc + - truncate_div + - truncate_mod + - xdivy + - xlogy + - zeta + - all + - amax + - amin + - aminmax + - any + - argmax + - argmin + - cummax + - cummin + - cumprod + - cumsum + - fmax + - histc + - logsumexp + - max + - mean + - median + - min + - norm + - prod + - std + - std_mean + - var + - var_mean + - argsort + - approximate_equal + - equal + - ge + - greater + - greater_equal + - gt + - intopk + - isclose + - isfinite + - isinf + - isnan + - isneginf + - isposinf + - isreal + - is_complex + - le + - less + - less_equal + - lt + - maximum + - minimum + - msort + - ne + - not_equal + - searchsorted + - topk + - bmm + - addbmm + - addmm + - baddbmm + - addr + - adjoint + - cholesky + - cholesky_solve + - batch_dot + - dot + - eig + - inner + - inverse + - geqrf + - ger + - kron + - lu_solve + - lu_unpack + - matmul + - matrix_solve + - matrix_band_part + - matrix_diag + - matrix_diag_part + - matrix_set_diag + - mm + - mv + - outer + - orgqr + - ormqr + - pinv + - svd + - tensor_dot + - logdet + - slogdet + - qr + - trace + - bartlett_window + - blackman_window + - hamming_window + - hann_window + - kaiser_window + - eye + - fill + - full + - full_like + - linspace + - logspace + - one_hot + - arange + - range + - heaviside + - bernoulli + - gamma + - laplace + - multinomial + - multinomial_with_replacement + - rand + - rand_like + - randint + - randint_like + - randn + - randn_like + - random_gamma + - random_poisson + - randperm + - standard_laplace + - standard_normal + - uniform + - argwhere + - batch_to_space_nd + - bincount + - block_diag + - broadcast_to + - cat + - channel_shuffle + - chunk + - column_stack + - concat + - conj + - count_nonzero + - deepcopy + - diag + - diagflat + - diagonal + - dyn_shape + - dsplit + - dstack + - einsum + - expand + - expand_dims + - flip + - fliplr + - flipud + - gather_d + - gather_elements + - gather_nd + - hsplit + - hstack + - index_add + - index_fill + - index_select + - inplace_add + - inplace_index_add + - inplace_sub + - inplace_update + - masked_fill + - masked_select + - meshgrid + - moveaxis + - movedim + - narrow + - nan_to_num + - nansum + - normal + - nonzero + - population_count + - rank + - repeat_elements + - repeat_interleave + - reshape + - reverse + - reverse_sequence + - roll + - scatter + - scatter_nd + - select + - sequence_mask + - shuffle + - size + - slice + - sort + - space_to_batch_nd + - sparse_segment_mean + - split + - squeeze + - stack + - strided_slice + - sum + - swapaxes + - swapdims + - tensor_scatter_add + - tensor_scatter_div + - tensor_scatter_max + - tensor_scatter_min + - tensor_scatter_mul + - tensor_scatter_sub + - tensor_scatter_elements + - tensor_split + - tile + - tril + - triu + - transpose + - unbind + - unique + - unique_consecutive + - unique_with_pad + - unsorted_segment_max + - unsorted_segment_min + - unsorted_segment_prod + - unsorted_segment_sum + - unsqueeze + - unstack + - view_as_real + - vsplit + - vstack + - where + - cross + - renorm + - is_tensor + - scalar_cast + - scalar_to_tensor + - tuple_to_array + - clip_by_global_norm + - clip_by_value + - assign + - assign_add + - assign_sub + - scatter_add + - scatter_div + - scatter_max + - scatter_min + - scatter_mul + - scatter_nd_add + - scatter_nd_div + - scatter_nd_max + - scatter_nd_min + - scatter_nd_mul + - scatter_nd_sub + - scatter_update + - derivative + - jet + +tensor: + - __abs__ + - __add__ + - __and__ + - __bool__ + - __eq__ + - __ge__ + - __gt__ + - __iadd__ + - __ifloordiv__ + - __imatmul__ + - __imod__ + - __imul__ + - __isub__ + - __le__ + - __lt__ + - __matmul__ + - __mod__ + - __mul__ + - __ne__ + - __neg__ + - __or__ + - __pow__ + - __radd__ + - __rmatmul__ + - __rmod__ + - __rmul__ + - __rpow__ + - __rsub__ + - __sub__ + - __truediv__ + - __xor__ + - abs + - absolute + - acos + - acosh + - add + - addbmm + - addcdiv + - addcmul + - addmm + - addmv + - addr + - all + - amax + - amin + - any + - arccos + - arccosh + - argmax + - angle + - arcsin + - arcsinh + - arctan + - arctanh + - argmin + - argsort + - asin + - asinh + - atan + - atan2 + - atanh + - baddbmm + - bernoulli + - bincount + - bitwise_and + - bitwise_or + - bitwise_xor + - bmm + - bool + - broadcast_to + - ceil + - cholesky_solve + - cholesky + - clamp + - clip + - conj + - copysign + - cos + - cosh + - cross + - cummax + - cummin + - cumprod + - cumsum + - deg2rad + - diag + - diagflat + - diff + - digamma + - div + - divide + - equal + - erf + - erfc + - erfinv + - exp + - expand_as + - expm1 + - flip + - fliplr + - flipud + - float_power + - floor + - fmod + - frac + - gather_elements + - ge + - geqrf + - ger + - greater + - greater_equal + - gt + - half + - hardshrink + - heaviside + - histc + - hypot + - i0 + - igamma + - igammac + - imag + - index_add + - index_fill + - index_put + - index_select + - inner + - int + - inverse + - isclose + - isfinite + - isinf + - isnan + - is_complex + - is_signed + - isneginf + - isposinf + - isreal + - lcm + - ldexp + - le + - lerp + - less + - less_equal + - log + - log10 + - log1p + - log2 + - logaddexp + - logaddexp2 + - logdet + - logical_and + - logical_not + - logical_or + - logical_xor + - logit + - logsumexp + - long + - lt + - masked_fill + - masked_scatter + - masked_select + - matmul + - max + - maximum + - mean + - median + - min + - minimum + - moveaxis + - movedim + - msort + - multinomial + - multiply + - mvlgamma + - nan_to_num + - nansum + - narrow + - ne + - neg + - negative + - nelement + - new_ones + - new_zeros + - nextafter + - norm + - nonzero + - not_equal + - ormqr + - permute + - pow + - prod + - qr + - ravel + - real + - reciprocal + - remainder + - renorm + - rad2deg + - tile + - repeat_interleave + - reshape + - reshape + - round + - rot90 + - rsqrt + - sum_to_size + - scatter + - sgn + - short + - sigmoid + - sign + - signbit + - sin + - sinc + - sinh + - slogdet + - sort + - split + - sqrt + - square + - squeeze + - std + - subtract + - subtract + - svd + - swapaxes + - swapdims + - t + - take + - tan + - tanh + - trace + - swapaxes + - tile + - to + - topk + - tril + - tensor_split + - transpose + - true_divide + - trunc + - unbind + - unique_consecutive + - unsqueeze + - var + - view + - where + - xlogy + - from_numpy + - std + - take + - var + - all + - any + - copy + - diagonal + - flatten + - resize + - sum + +mint.ops: + - abs + - absolute_import + - add + - add_ex + - all + - any + - any_ex + - arange + - argmax + - avg_pool2d + - baddbmm + - baddbmm_ex + - batch_norm + - binary_cross_entropy_with_logits + - bitwise_and + - bitwise_or + - bitwise_xor + - bmm + - broadcast_to + - cat + - cat_ex + - ceil + - chunk + - clamp + - conv2d + - conv_transpose2d + - cos + - cross + - cummax + - cummin + - cumsum + - div + - divide + - dropout + - embedding + - eq + - erf + - erfinv + - exp + - flatten + - flip + - flip_ex + - fold + - full + - functional + - gather + - gelu + - greater + - grid_sample + - group_norm + - gt + - index_select + - interpolate + - isclose + - isfinite + - layer_norm + - le + - leaky_relu + - less + - less_equal + - linear + - linspace + - log + - logical_and + - logical_not + - logical_or + - lt + - masked_select + - matmul + - max + - max_pool2d + - maximum + - mean + - mean_ex + - min + - minimum + - mul + - ne + - neg + - negative + - nn + - nonzero + - normal + - one_hot + - ones + - ones_ex + - ones_like + - pad + - permute + - permute_ex + - pow + - prod + - reciprocal + - relu + - remainder + - repeat_interleave + - rsqrt + - scatter + - scatter_add + - searchsorted + - sigmoid + - silu + - sin + - softmax + - softplus + - sort + - split + - sqrt + - sqrt_ex + - square + - stack + - sub + - sub_ex + - sum + - tanh + - tile + - topk + - tril + - triu + - unfold + - unique + - where + - xlogy + - zeros + - zeros_ex + - zeros_like + +mint.nn: + - Dropout + - Embedding + - Fold + - LayerNorm + - Linear + - MaxPool2d + - Unfold + - Upsample + +mint.nn.functional: + - absolute_import + - avg_pool2d + - batch_norm + - batch_norm_ex + - bce_with_logits + - binary_cross_entropy_with_logits + - conv_transpose2d + - dense + - dropout + - embedding + - fold + - gelu + - grid_sample + - group_norm + - interpolate + - layer_norm + - leaky_relu + - linear + - max_pool2d + - max_pool2d_ex + - normal + - one_hot + - one_hot_ext + - pad + - relu + - sigmoid + - silu + - softmax + - softmax_ex + - softplus + - tanh + - unfold diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_functional.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..be3d1bd2545cbdc5d1b03044feed74841d5f2f91 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_functional.py @@ -0,0 +1,94 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import yaml +import mindspore as ms +from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell +from msprobe.core.common.utils import Const +from msprobe.core.common.file_check import FileOpen + + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") + + +def load_ops_functions(): + ops_func = {f: getattr(ms.ops, f) for f in dir(ms.ops)} + mint_ops_func = {f: getattr(ms.mint, f) for f in dir(ms.mint)} + mint_func_ops_func = {f: getattr(ms.mint.nn.functional, f) for f in dir(ms.mint.nn.functional)} + return ops_func, mint_ops_func, mint_func_ops_func + + +def get_functional_ops(): + ops_func, mint_ops_func, mint_func_ops_func = load_ops_functions() + with FileOpen(yaml_path, 'r') as f: + config = yaml.safe_load(f) + WrapFunctionalOps = config.get("ops") + WrapMintOps = config.get("mint.ops") + WrapMintFunctionalOps = config.get("mint.nn.functional") + return ( + set(WrapFunctionalOps) & set(ops_func.keys()), + set(WrapMintOps) & set(mint_ops_func.keys()), + set(WrapMintFunctionalOps) & set(mint_func_ops_func.keys()) + ) + + +class HOOKFunctionalOP(object): + pass + + +class HOOKMintOP(object): + pass + + +class HOOKMintNNFunctionalOP(object): + pass + + +class FunctionalOPTemplate(HOOKCell): + def __init__(self, op_name, op_dict, prefix, hook): + self.op_name = op_name + self.op_func = op_dict[op_name] + self.prefix_op_name_ = prefix + str(op_name.split(Const.SEP)[-1]) + Const.SEP + super().__init__(hook) + + def construct(self, *args, **kwargs): + if self.op_name.startswith('dropout'): + return args[0] if args else kwargs.get('input') + return self.op_func(*args, **kwargs) + + +def wrap_functional_op(op_name, op_dict, prefix, hook): + def op_template(*args, **kwargs): + return FunctionalOPTemplate(op_name, op_dict, prefix, hook)(*args, **kwargs) + return op_template + + +def wrap_functional_ops_and_bind(ops, op_dict, prefix, hook, hook_class): + for op_name in ops: + if callable(op_dict[op_name]): + setattr(hook_class, Const.ATTR_NAME_PREFIX + op_name, wrap_functional_op(op_name, op_dict, prefix, hook)) + + +def setup_hooks(hook): + functional_ops, mint_ops, mint_func_ops = get_functional_ops() + wrap_functional_ops_and_bind( + functional_ops, {f: getattr(ms.ops, f) for f in dir(ms.ops)}, "Functional.", hook, HOOKFunctionalOP) + wrap_functional_ops_and_bind( + mint_ops, {f: getattr(ms.mint, f) for f in dir(ms.mint)}, "Mint.", hook, HOOKMintOP) + wrap_functional_ops_and_bind( + mint_func_ops, {f: getattr(ms.mint.nn.functional, f) for f in dir(ms.mint.nn.functional)}, "MintFunctional.", hook, HOOKMintNNFunctionalOP) + diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_tensor.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6a9a979dd37851d5d33bc0edf6ee37095e1e7e --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/wrap_tensor.py @@ -0,0 +1,66 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import yaml +import mindspore as ms + +from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell +from msprobe.core.common.utils import Const +from msprobe.core.common.file_check import FileOpen + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapTensorOps = yaml.safe_load(f).get('tensor') + +TensorFunc = {} +for f in dir(ms.Tensor): + TensorFunc[f] = getattr(ms.Tensor, f) + + +def get_tensor_ops(): + global WrapTensorOps + _tensor_ops = dir(ms.Tensor) + return set(WrapTensorOps) & set(_tensor_ops) + + +class HOOKTensor(object): + pass + + +class TensorOPTemplate(HOOKCell): + + def __init__(self, op_name, hook): + self.op_name_ = op_name + self.prefix_op_name_ = "Tensor." + str(op_name) + Const.SEP + super().__init__(hook) + + def construct(self, *args, **kwargs): + return TensorFunc[str(self.op_name_)](*args, **kwargs) + + +def wrap_tensor_op(op_name, hook): + def tensor_op_template(*args, **kwargs): + return TensorOPTemplate(op_name, hook)(*args, **kwargs) + + return tensor_op_template + + +def wrap_tensor_ops_and_bind(hook): + _tensor_ops = get_tensor_ops() + for op_name in _tensor_ops: + if callable(TensorFunc[op_name]): + setattr(HOOKTensor, Const.ATTR_NAME_PREFIX + str(op_name), wrap_tensor_op(op_name, hook)) diff --git a/debug/accuracy_tools/atat/mindspore/dump/kernel_graph_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_graph_dump.py similarity index 92% rename from debug/accuracy_tools/atat/mindspore/dump/kernel_graph_dump.py rename to debug/accuracy_tools/msprobe/mindspore/dump/kernel_graph_dump.py index 190e6bc4d5591f9ed5aa466c2be631fb224fc89b..8320ee0906458734b29b9b911351739fefb77163 100644 --- a/debug/accuracy_tools/atat/mindspore/dump/kernel_graph_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_graph_dump.py @@ -1,9 +1,9 @@ import os import json -from atat.core.common.utils import make_dump_path_if_not_exists -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.core.common.log import logger -from atat.core.common.file_check import FileOpen +from msprobe.core.common.utils import make_dump_path_if_not_exists +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.core.common.log import logger +from msprobe.core.common.file_check import FileOpen class KernelGraphDump: diff --git a/debug/accuracy_tools/atat/mindspore/ms_config.py b/debug/accuracy_tools/msprobe/mindspore/ms_config.py similarity index 67% rename from debug/accuracy_tools/atat/mindspore/ms_config.py rename to debug/accuracy_tools/msprobe/mindspore/ms_config.py index 02cead32f1f5fc2b00c47d75ac9d9950a3cd258d..c0ef6bb6c00aab426fd42a11c3bc2436440a4a6a 100644 --- a/debug/accuracy_tools/atat/mindspore/ms_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/ms_config.py @@ -1,6 +1,7 @@ import json -from atat.core.common_config import CommonConfig, BaseConfig -from atat.core.common.file_check import FileOpen +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.const import Const class TensorConfig(BaseConfig): @@ -31,39 +32,43 @@ class StatisticsConfig(BaseConfig): if self.data_mode is not None and len(self.data_mode) > 0: if len(self.data_mode) > 1 or self.data_mode[0] not in ["all", "input", "output"]: raise Exception("data_mode must be all, input or output") + if self.summary_mode and self.summary_mode not in ["statistics", "md5"]: + raise Exception("summary_mode is invalid") -class OverflowCheck(BaseConfig): +class OverflowCheckConfig(BaseConfig): def __init__(self, json_config): super().__init__(json_config) - self.file_format = None - self.check_mode = json_config.get("check_mode") + self.data_mode = ["all"] self._check_config() def _check_config(self): - if self.data_mode is not None and len(self.data_mode) > 0: - if len(self.data_mode) > 1 or self.data_mode[0] not in ["all", "input", "output"]: - raise Exception("data_mode must be all, input or output") + if self.overflow_nums is not None and not isinstance(self.overflow_nums, int): + raise Exception("overflow_nums is invalid, it should be an integer") + if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0: + raise Exception("overflow_nums should be -1 or positive integer") if self.check_mode and self.check_mode not in ["all", "aicore", "atomic"]: raise Exception("check_mode is invalid") +TaskDict = { + Const.TENSOR: TensorConfig, + Const.STATISTICS: StatisticsConfig, + Const.OVERFLOW_CHECK: OverflowCheckConfig, +} + + def parse_common_config(json_config): return CommonConfig(json_config) def parse_task_config(task, json_config): - task_map = json_config[task] + task_map = json_config.get(task) if not task_map: task_map = dict() - if task == "tensor": - return TensorConfig(task_map) - elif task == "statistics": - return StatisticsConfig(task_map) - elif task == "overflow_check": - return OverflowCheck(task_map) - else: + if task not in TaskDict: raise Exception("task is invalid.") + return TaskDict.get(task)(task_map) def parse_json_config(json_file_path): @@ -73,6 +78,6 @@ def parse_json_config(json_file_path): json_config = json.load(file) common_config = parse_common_config(json_config) if not common_config.task: - common_config.task = "statistics" + common_config.task = Const.STATISTICS task_config = parse_task_config(common_config.task, json_config) return common_config, task_config diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/__init__.py b/debug/accuracy_tools/msprobe/mindspore/overflow_check/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/__init__.py rename to debug/accuracy_tools/msprobe/mindspore/overflow_check/__init__.py diff --git a/debug/accuracy_tools/atat/mindspore/overflow_check/kernel_graph_overflow_check.py b/debug/accuracy_tools/msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py similarity index 89% rename from debug/accuracy_tools/atat/mindspore/overflow_check/kernel_graph_overflow_check.py rename to debug/accuracy_tools/msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py index 7a677eb3c70c583e745785d3b8b988cbbe93e7dd..6640608735d98c17c1b544b58183224a1cd4ba55 100644 --- a/debug/accuracy_tools/atat/mindspore/overflow_check/kernel_graph_overflow_check.py +++ b/debug/accuracy_tools/msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py @@ -1,9 +1,9 @@ import os import json -from atat.core.common.utils import make_dump_path_if_not_exists -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.core.common.log import logger -from atat.core.common.file_check import FileOpen +from msprobe.core.common.utils import make_dump_path_if_not_exists +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.core.common.log import logger +from msprobe.core.common.file_check import FileOpen class KernelGraphOverflowCheck: diff --git a/debug/accuracy_tools/atat/mindspore/overflow_check/overflow_check_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py similarity index 81% rename from debug/accuracy_tools/atat/mindspore/overflow_check/overflow_check_tool_factory.py rename to debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py index fe53359be1ba1ecb73fb84138228415f68e1c2ce..d809c714211aa34f41588bba332b55ed808b5376 100644 --- a/debug/accuracy_tools/atat/mindspore/overflow_check/overflow_check_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/overflow_check/overflow_check_tool_factory.py @@ -1,5 +1,5 @@ -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck class OverflowCheckToolFactory: diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py new file mode 100644 index 0000000000000000000000000000000000000000..50776aaf1097339e7c6d98944db7ddf2d2238c5f --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -0,0 +1,152 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import copy +from pathlib import Path +import functools +from collections import defaultdict + +from msprobe.core.data_dump.data_collector import build_data_collector +from msprobe.core.data_dump.scope import BaseScope +from msprobe.mindspore.common.utils import get_rank_if_initialized +from msprobe.core.common.file_check import FileChecker, FileCheckConst, check_path_before_create +from msprobe.mindspore.common.log import logger +from msprobe.core.common.utils import Const +from msprobe.core.common.exceptions import DistributedNotInitializedError +from msprobe.mindspore.dump.hook_cell.api_registry import api_register +from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs +from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell + + +class Service: + def __init__(self, config): + self.model = None + self.config = copy.deepcopy(config) + self.config.level = self.config.level_ori + self.data_collector = build_data_collector(self.config) + self.switch = False + self.current_iter = 0 + self.first_start = True + self.current_rank = None + self.dump_iter_dir = None + self.start_call = False + + def build_hook(self, module_type, name): + def forward_hook(api_or_module_name, module, input, output): + self.data_collector.visit_and_clear_overflow_status(api_or_module_name) + if not self.switch: + return None + if self.data_collector: + module_input_output = ModuleForwardInputsOutputs(args=input, kwargs=module.input_kwargs, output=output) + self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output) + if self.data_collector.if_return_forward_new_output(): + return self.data_collector.get_forward_new_output() + del module.input_kwargs + return output + + def backward_hook(api_or_module_name, module, grad_input, grad_output): + self.data_collector.visit_and_clear_overflow_status(api_or_module_name) + if not self.switch: + return + if self.data_collector: + module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output) + self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output) + + pid = os.getpid() + forward_name_template = name + Const.FORWARD + backward_name_template = name + Const.BACKWARD + forward_hook = functools.partial(forward_hook, forward_name_template) + backward_hook = functools.partial(backward_hook, backward_name_template) + + def wrap_forward_hook(*args, **kwargs): + return forward_hook(*args, **kwargs) + + def wrap_backward_hook(*args, **kwargs): + return backward_hook(*args, **kwargs) + + return wrap_forward_hook, wrap_backward_hook + + def step(self): + self.current_iter += 1 + self.data_collector.update_iter(self.current_iter) + HOOKCell.cell_count = defaultdict(int) + + def start(self, model=None): + self.model = model + self.start_call = True + logger.info("msprobe: debugger.start() is set successfully") + if self.config.step and self.current_iter > max(self.config.step): + self.stop() + raise Exception("msprobe: exit after iteration {}".format(max(self.config.step))) + if self.config.step and self.current_iter not in self.config.step: + return + if self.first_start: + try: + self.current_rank = get_rank_if_initialized() + except DistributedNotInitializedError: + self.current_rank = None + + if self.config.rank and self.current_rank not in self.config.rank: + return + self.register_hook_new() + self.first_start = False + self.switch = True + logger.info(f"Dump switch is turned on at step {self.current_iter}. ") + self.create_dirs() + logger.info(f"Dump data will be saved in {self.dump_iter_dir}.") + + def stop(self): + logger.info("msprobe: debugger.stop() is set successfully. " + "Please set debugger.start() to turn on the dump switch again. ") + if not self.start_call: + logger.error("msprobe: debugger.start() is not set in the current scope.") + raise Exception("debugger.start() is not set in the current scope.") + if self.config.step and self.current_iter not in self.config.step: + return + if self.config.rank and self.current_rank not in self.config.rank: + return + self.switch = False + self.start_call = False + self.data_collector.write_json() + + def create_dirs(self): + check_path_before_create(self.config.dump_path) + if not os.path.exists(self.config.dump_path): + Path(self.config.dump_path).mkdir(mode=0o750, exist_ok=True) + file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR) + file_check.common_check() + self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}") + cur_rank = self.current_rank if self.current_rank is not None else '' + dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") + if not os.path.exists(dump_dir): + Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True) + if self.config.task in self.data_collector.tasks_need_tensor_data: + dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") + Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True) + else: + dump_data_dir = None + + dump_file_path = os.path.join(dump_dir, "dump.json") + stack_file_path = os.path.join(dump_dir, "stack.json") + construct_file_path = os.path.join(dump_dir, "construct.json") + self.data_collector.update_dump_paths( + dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None) + + def register_hook_new(self): + logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task)) + if self.config.level == "L1": + api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)) + api_register.api_set_hook_func() diff --git a/debug/accuracy_tools/atat/mindspore/task_handler_factory.py b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py similarity index 68% rename from debug/accuracy_tools/atat/mindspore/task_handler_factory.py rename to debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py index 4f80e4e89c92156762ea0e4c4ed3302cc5c31f5f..7b7e6fd889c775a4491e824c1f73e6021cb99350 100644 --- a/debug/accuracy_tools/atat/mindspore/task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py @@ -1,6 +1,6 @@ -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.dump.dump_tool_factory import DumpToolFactory -from atat.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory +from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory class TaskHandlerFactory: diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/msprobe/msprobe.py similarity index 74% rename from debug/accuracy_tools/atat/atat.py rename to debug/accuracy_tools/msprobe/msprobe.py index 90f8215b102d4f120b879773797c4b4864b25d6b..9c96b193994410403e5be9bc592e76999f6bc8b2 100644 --- a/debug/accuracy_tools/atat/atat.py +++ b/debug/accuracy_tools/msprobe/msprobe.py @@ -15,19 +15,20 @@ import argparse import sys -from atat.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command -from atat.pytorch.parse_tool.cli import parse as cli_parse -from atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut -from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \ +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command +from msprobe.pytorch.parse_tool.cli import parse as cli_parse +from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut +from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \ _api_precision_compare_command -from atat.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \ +from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \ _run_overflow_check_command +from msprobe.pytorch.torchair_compare.torchair_compare_cli import torchair_compare_parser, torchair_compare_cli def main(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, - description="atat(ascend training accuracy tools), [Powered by MindStudio].\n" + description="msprobe(mindstudio probe), [Powered by MindStudio].\n" "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n" f"For any issue, refer README.md first", ) @@ -46,6 +47,8 @@ def main(): help='Number of splits for parallel processing. Range: 1-64') _api_precision_compare_parser(api_precision_compare_cmd_parser) _run_overflow_check_parser(run_overflow_check_cmd_parser) + torchair_compare_cmd_parser = subparsers.add_parser("torchair_compare") + torchair_compare_parser(torchair_compare_cmd_parser) if len(sys.argv) == 1: parser.print_help() sys.exit(0) @@ -61,6 +64,8 @@ def main(): _api_precision_compare_command(args) elif sys.argv[3] == "run_overflow_check": _run_overflow_check_command(args) + elif sys.argv[3] == "torchair_compare": + torchair_compare_cli(args) if __name__ == "__main__": diff --git a/debug/accuracy_tools/atat/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py similarity index 74% rename from debug/accuracy_tools/atat/pytorch/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/__init__.py index 482e850f7baa845bd831e0d4728e841661b9345b..58ab1ac35a3dcb25474c1da02f4d84f9de526346 100644 --- a/debug/accuracy_tools/atat/pytorch/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py @@ -2,3 +2,4 @@ from .debugger.precision_debugger import PrecisionDebugger from .common.utils import seed_all from .compare.acc_compare import compare from .compare.distributed_compare import compare_distributed +from .visualization.graph_service import compare_graph, build_graph diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py similarity index 93% rename from debug/accuracy_tools/atat/pytorch/advisor/advisor.py rename to debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py index 43b3f40f97948808a987bd4211530cfca2cb025a..fe5bf1efb1f0e520c9cf77cd1cb808650e66e548 100644 --- a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py +++ b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py @@ -17,12 +17,12 @@ import os -from atat.pytorch.advisor.advisor_result import AdvisorResult -from atat.pytorch.advisor.advisor_const import AdvisorConst -from atat.pytorch.common.log import logger -from atat.core.common.utils import CompareException -from atat.core.common.file_check import FileChecker -from atat.core.common.const import Const, CompareConst, FileCheckConst +from msprobe.pytorch.advisor.advisor_result import AdvisorResult +from msprobe.pytorch.advisor.advisor_const import AdvisorConst +from msprobe.pytorch.common.log import logger +from msprobe.core.common.utils import CompareException +from msprobe.core.common.file_utils import FileChecker +from msprobe.core.common.const import Const, CompareConst, FileCheckConst class Advisor: """ diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor_const.py b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_const.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/advisor/advisor_const.py rename to debug/accuracy_tools/msprobe/pytorch/advisor/advisor_const.py diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py similarity index 90% rename from debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py rename to debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py index a24fa2a1155501d91eb2462528f71824f091f318..58b76d3c8479321eb354e27c785a38d4ca3d8aaa 100644 --- a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py +++ b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py @@ -17,10 +17,10 @@ import os import time -from atat.pytorch.advisor.advisor_const import AdvisorConst -from atat.pytorch.common.log import logger -from atat.core.common.const import Const, FileCheckConst -from atat.core.common.file_check import change_mode +from msprobe.pytorch.advisor.advisor_const import AdvisorConst +from msprobe.pytorch.common.log import logger +from msprobe.core.common.const import Const, FileCheckConst +from msprobe.core.common.file_utils import change_mode class AdvisorResult: diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/.keep b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/.keep similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/.keep rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/.keep diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/.keep b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/.keep similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/.keep rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/.keep diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py similarity index 58% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py index 0aceb691b2530bd8117396ef34c9cd4154c76716..ca6bb1627e736ded4050bc3fd0bab9f5d64a1d51 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py @@ -1,10 +1,8 @@ import os import yaml -from atat.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path -from atat.pytorch.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps -from atat.core.common.file_check import FileOpen - -WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps) +from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path +from msprobe.core.common.file_utils import FileOpen +from msprobe.pytorch.pt_config import RunUTConfig class Config: @@ -14,9 +12,17 @@ class Config: config = yaml.safe_load(file) self.config = {key: self.validate(key, value) for key, value in config.items()} - def validate(self, key, value): + def __getattr__(self, item): + return self.config[item] + + def __str__(self): + return '\n'.join(f"{key}={value}" for key, value in self.config.items()) + + @staticmethod + def validate(key, value): validators = { 'white_list': list, + 'black_list': list, 'error_data_path': str, 'precision': int } @@ -27,22 +33,13 @@ class Config: if key == 'precision' and value < 0: raise ValueError("precision must be greater than 0") if key == 'white_list': - if not isinstance(value, list): - raise ValueError("white_list must be a list type") - if not all(isinstance(i, str) for i in value): - raise ValueError("All elements in white_list must be of str type") - invalid_api = [i for i in value if i not in WrapApi] - if invalid_api: - raise ValueError( - f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list") + RunUTConfig.check_filter_list_config(key, value) + if key == 'black_list': + RunUTConfig.check_filter_list_config(key, value) + if key == 'error_data_path': + RunUTConfig.check_error_data_path_config(value) return value - def __getattr__(self, item): - return self.config[item] - - def __str__(self): - return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) yaml_path = os.path.join(cur_path, "config.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py similarity index 96% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py index 9e1b02c0154760f468c8d603e4bf385777258434..7855a51e4b472fd3d22b5ad9ee2f7e4d4a0d39e5 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py @@ -28,10 +28,10 @@ except ImportError: else: IS_GPU = False -from atat.pytorch.common.log import logger -from atat.core.common.file_check import FileChecker, FileOpen, change_mode, create_directory -from atat.core.common.const import Const, FileCheckConst -from atat.core.common.utils import CompareException +from msprobe.pytorch.common.log import logger +from msprobe.core.common.file_utils import FileChecker, FileOpen, change_mode, create_directory +from msprobe.core.common.const import Const, FileCheckConst +from msprobe.core.common.utils import CompareException class DumpException(CompareException): @@ -166,6 +166,7 @@ def initialize_save_path(save_path, dir_name): os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) data_path_checker = FileChecker(data_path, FileCheckConst.DIR) data_path_checker.common_check() + return data_path def write_pt(file_path, tensor): diff --git a/debug/accuracy_tools/atat/pytorch/debugger/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/debugger/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py similarity index 98% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py index 3982c167cca13617dc19332621b5ebc9fce12997..1bb19cc048e88c9353a19069031bb8acfdae05e9 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py @@ -2,8 +2,8 @@ import torch import numpy as np -from atat.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS -from atat.core.common.const import CompareConst +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS +from msprobe.core.common.const import CompareConst DEFAULT_THRESHOLD = 1 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py similarity index 97% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py index f73c83c4881a5c7b9664314a2051ba8ba770c0a1..d85abfe9ddaf24ce19ecfef965886db93e16761c 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -7,19 +7,19 @@ from collections import namedtuple import torch import pandas as pd -from atat.pytorch.api_accuracy_checker.common.utils import write_csv -from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig -from atat.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \ +from msprobe.pytorch.api_accuracy_checker.common.utils import write_csv +from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \ API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \ ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, ThousandthStandardApi, \ BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \ check_inf_or_nan -from atat.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn -from atat.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path -from atat.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory -from atat.pytorch.common.log import logger -from atat.core.common.utils import CompareException -from atat.core.common.const import CompareConst, FileCheckConst +from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path +from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, create_directory +from msprobe.pytorch.common.log import logger +from msprobe.core.common.utils import CompareException +from msprobe.core.common.const import CompareConst, FileCheckConst CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) BenchmarkInf_Nan_Consistency = namedtuple('BenchmarkInf_Nan_Consistency', ['small_value_inf_nan_consistency', diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py similarity index 97% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index ca35c8ed5daccf2c7c371ea1ae625c1e2435d743..ee49588288efc0a33c086913cc5624059de82272 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -3,18 +3,18 @@ import os from collections import namedtuple import torch import numpy as np -from atat.pytorch.common.log import logger -from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv -from atat.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \ +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \ DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, \ ULPStandardApi, ThousandthStandardApi, apis_threshold -from atat.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn -from atat.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \ +from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \ get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err -from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig -from atat.core.common.const import Const, CompareConst +from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig +from msprobe.core.common.const import Const, CompareConst ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status', diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py similarity index 98% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py index 9867a76fadff6eb071236cfabda93189547124a3..fb6d5dcc0f1c8b67ec2b67e8b419e8407cdc8d6d 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py @@ -1,4 +1,4 @@ -from atat.core.common.const import CompareConst +from msprobe.core.common.const import CompareConst class CompareColumn: diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py similarity index 97% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py index b7b32e41e47ac69798f0bf0c15c5cdb929574a11..f8450b64b56a24fde76f0fc5a7be46479b5059ea 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -5,10 +5,10 @@ import math import numpy as np import torch import yaml -from atat.core.common.utils import CompareException -from atat.core.common.const import Const -from atat.pytorch.common.log import logger -from atat.core.common.file_check import FileOpen +from msprobe.core.common.utils import CompareException +from msprobe.core.common.const import Const +from msprobe.pytorch.common.log import logger +from msprobe.core.common.file_utils import FileOpen current_time = time.strftime("%Y%m%d%H%M%S") diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml similarity index 77% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml index 7f26c72aa3102fa658c44caf7f5ce789cb23c46d..2dac535dc0501f6e47f0cdcc48bd88e1d73ab0dd 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml @@ -1,4 +1,5 @@ white_list: [] +black_list: [] error_data_path: './' precision: 14 \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/.keep b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/.keep similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/.keep rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/.keep diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py similarity index 95% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index 97b2dcc7e404c399aca0838b57650c0c54933bf4..51477ddd46efe607b4407814c753d07a6f714ad9 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -20,10 +20,12 @@ import math import torch import numpy -from atat.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api -from atat.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path, check_object_type, get_full_data_path, CompareException -from atat.pytorch.common.log import logger -from atat.core.common.const import Const +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api +from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \ + CompareException +from msprobe.core.common.file_utils import FileChecker +from msprobe.pytorch.common.log import logger +from msprobe.core.common.const import Const, FileCheckConst TORCH_TYPE = ["torch.device", "torch.dtype"] TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] @@ -86,12 +88,13 @@ def gen_real_tensor(data_path, convert_type): convert_type: convert ori_type to dist_type flag. """ data_path = os.path.realpath(data_path) - check_file_or_directory_path(data_path) + data_path_checker = FileChecker(data_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE) + data_path = data_path_checker.common_check() if not data_path.endswith('.pt') and not data_path.endswith('.npy'): error_info = f"The file: {data_path} is not a pt or numpy file." raise CompareException(CompareException.INVALID_FILE_ERROR, error_info) if data_path.endswith('.pt'): - data = torch.load(data_path).cpu() + data = torch.load(data_path, map_location=torch.device('cpu')) else: data_np = numpy.load(data_path) data = torch.from_numpy(data_np) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py similarity index 86% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index d2ab9c1e952001f4551044f4395662dd25931e08..049f6e9de5f590d5eb78a9ff2614a570c8bb88cb 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -9,14 +9,14 @@ import threading from collections import namedtuple from itertools import cycle from tqdm import tqdm -from atat.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, \ +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, \ get_validated_details_csv_path, preprocess_forward_content -from atat.pytorch.api_accuracy_checker.compare.compare import Comparator -from atat.pytorch.common import parse_json_info_forward_backward -from atat.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \ +from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator +from msprobe.pytorch.common import parse_json_info_forward_backward +from msprobe.core.common.file_utils import FileChecker, check_file_suffix, check_link, FileOpen, \ check_path_before_create, create_directory -from atat.pytorch.common.log import logger -from atat.core.common.const import FileCheckConst +from msprobe.pytorch.common.log import logger +from msprobe.core.common.const import FileCheckConst def split_json_file(input_file, num_splits, filter_api): @@ -68,7 +68,7 @@ signal.signal(signal.SIGTERM, signal_handler) ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits', 'save_error_data_flag', 'jit_compile_flag', 'device_id', - 'result_csv_path', 'total_items', 'real_data_path']) + 'result_csv_path', 'total_items', 'config_path']) def run_parallel_ut(config): @@ -90,7 +90,7 @@ def run_parallel_ut(config): *(['-j'] if config.jit_compile_flag else []), *(['-save_error_data'] if config.save_error_data_flag else []), '-csv_path', config.result_csv_path, - *(['-real_data_path', config.real_data_path] if config.real_data_path else []) + *(['-config', config.config_path] if config.config_path else []) ] return cmd @@ -110,14 +110,9 @@ def run_parallel_ut(config): def update_progress_bar(progress_bar, result_csv_path): while any(process.poll() is None for process in processes): - try: - with open(result_csv_path, 'r') as result_file: - completed_items = len(result_file.readlines()) - 1 - progress_bar.update(completed_items - progress_bar.n) - except FileNotFoundError: - logger.warning(f"Result CSV file not found: {result_csv_path}.") - except Exception as e: - logger.error(f"An unexpected error occurred while reading result CSV: {e}") + with FileOpen(result_csv_path, 'r') as result_file: + completed_items = len(result_file.readlines()) - 1 + progress_bar.update(completed_items - progress_bar.n) time.sleep(1) for api_info in config.api_files: @@ -175,7 +170,7 @@ def prepare_config(args): out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api) - + config_path = os.path.realpath(args.config_path) if args.config_path else None result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv") if not args.result_csv_path: details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv") @@ -187,7 +182,7 @@ def prepare_config(args): logger.info(f"UT task details will be saved in {details_csv_path}") return ParallelUTConfig(split_files, out_path, args.num_splits, args.save_error_data, args.jit_compile, args.device_id, result_csv_path, - total_items, args.real_data_path) + total_items, config_path) def main(): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py similarity index 85% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index c5834e9a8c738b4766612263ccb0d5cae24add9d..1b9b26f9c0e3b954c577b68d35d8c5e086660b73 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -10,10 +10,13 @@ else: is_gpu = False import torch from tqdm import tqdm -from atat.pytorch.api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info -from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents -from atat.core.common.file_check import check_link -from atat.pytorch.common.log import logger +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info +from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents +from msprobe.core.common.file_utils import check_link +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward +from msprobe.core.common.const import Const + def check_tensor_overflow(x): if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool: @@ -52,12 +55,12 @@ def check_data_overflow(x): def run_overflow_check(forward_file): logger.info("start UT test") - forward_content = get_json_contents(forward_file) + forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file) for api_full_name, api_info_dict in tqdm(forward_content.items()): try: - run_torch_api(api_full_name, api_info_dict) + run_torch_api(api_full_name, api_info_dict, real_data_path) except Exception as err: - api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0] + _, api_name, _ = api_full_name.split(Const.SEP) if "not implemented for 'Half'" in str(err): logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API " f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.") @@ -68,11 +71,10 @@ def run_overflow_check(forward_file): logger.error(f"Run {api_full_name} UT Error: %s" % str(err)) -def run_torch_api(api_full_name, api_info_dict): +def run_torch_api(api_full_name, api_info_dict, real_data_path): torch.npu.clear_npu_overflow_flag() - api_type = api_full_name.split(".")[0] - api_name = api_full_name.split(".", 1)[1].rsplit(".", 2)[0] - args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path='') + api_type, api_name, _ = api_full_name.split(Const.SEP) + args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path) if not need_grad: logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward." % api_full_name) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py similarity index 85% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index cd83a95801ffecf5bc28f9908378eb09f945d6a5..2fb709127d893320988b64db634c92b1eda46d60 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -18,28 +18,32 @@ else: import torch from tqdm import tqdm -from atat.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api -from atat.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args -from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents, api_info_preprocess, \ +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api +from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args +from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents, api_info_preprocess, \ initialize_save_path, UtDataProcessor -from atat.pytorch.api_accuracy_checker.compare.compare import Comparator -from atat.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn -from atat.pytorch.hook_module.wrap_tensor import TensorOPTemplate -from atat.pytorch.hook_module.wrap_functional import FunctionalOPTemplate -from atat.pytorch.hook_module.wrap_torch import TorchOPTemplate -from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig -from atat.pytorch.common.parse_json import parse_json_info_forward_backward -from atat.core.common.file_check import FileOpen, FileChecker, \ +from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator +from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn +from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate +from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate +from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate +from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate +from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate +from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig +from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward +from msprobe.core.common.file_utils import FileOpen, FileChecker, \ change_mode, check_file_suffix, check_link, check_path_before_create, create_directory -from atat.pytorch.common.log import logger -from atat.core.common.const import Const, FileCheckConst, CompareConst +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.pt_config import parse_json_config +from msprobe.core.common.const import Const, FileCheckConst, CompareConst current_time = time.strftime("%Y%m%d%H%M%S") UT_ERROR_DATA_DIR = 'ut_error_data' + current_time RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', - 'save_error_data', 'is_continue_run_ut', 'real_data_path']) + 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list', + 'black_list', 'error_data_path']) not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} not_raise_dtype_set = {'type_as'} @@ -76,6 +80,12 @@ def exec_api(api_type, api_name, args, kwargs): if api_type == "Torch": torch_api = TorchOPTemplate(api_name, str, False) out = torch_api.forward(*args, **kwargs) + if api_type == "Aten": + torch_api = AtenOPTemplate(api_name, None, False) + out = torch_api.forward(*args, **kwargs) + if api_type == "NPU": + torch_api = NpuOPTemplate(api_name, None, False) + out = torch_api.forward(*args, **kwargs) return out @@ -176,8 +186,7 @@ def run_ut(config): logger.info(f"UT task result will be saved in {config.result_csv_path}") logger.info(f"UT task details will be saved in {config.details_csv_path}") if config.save_error_data: - error_data_path = os.path.abspath(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR)) - logger.info(f"UT task error_datas will be saved in {error_data_path}") + logger.info(f"UT task error_datas will be saved in {config.error_data_path}") compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut) with FileOpen(config.result_csv_path, 'r') as file: csv_reader = csv.reader(file) @@ -188,17 +197,17 @@ def run_ut(config): continue if is_unsupported_api(api_full_name): # TODO run_ut does not support to the npu fusion api and distributed api continue + [_, api_name, _] = api_full_name.split(Const.SEP) try: - if msCheckerConfig.white_list: - [_, api_name, _] = api_full_name.split(Const.SEP) - if api_name not in set(msCheckerConfig.white_list): - continue + if config.black_list and api_name in config.black_list: + continue + if config.white_list and api_name not in config.white_list: + continue data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict) is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info) if config.save_error_data: - do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) + do_save_error_data(api_full_name, data_info, config.error_data_path, is_fwd_success, is_bwd_success) except Exception as err: - [_, api_name, _] = api_full_name.split(Const.SEP) if "expected scalar type Long" in str(err): logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") @@ -227,16 +236,16 @@ def is_unsupported_api(api_name): return flag -def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): +def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success): if not is_fwd_success or not is_bwd_success: - processor = UtDataProcessor(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR)) + processor = UtDataProcessor(error_data_path) for element in data_info.in_fwd_data_list: processor.save_tensors_in_element(api_full_name + '.forward.input', element) - processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_out) - processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_out) + processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_output) + processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_output) processor.save_tensors_in_element(api_full_name + '.backward.input', data_info.grad_in) - processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad_out) - processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad_out) + processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad) + processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad) def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict): @@ -273,7 +282,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict if need_backward: if need_to_backward(grad_index, out): - backward_args = backward_content[api_full_name].get("grad_output") + backward_args = backward_content[api_full_name].get("input") grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0] bench_grad, _ = generate_cpu_params(grad, {}, False, api_name) bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out) @@ -314,14 +323,14 @@ def run_backward(args, grad, grad_index, out): return grad_out -def initialize_save_error_data(): - error_data_path = msCheckerConfig.error_data_path +def initialize_save_error_data(error_data_path): check_path_before_create(error_data_path) create_directory(error_data_path) - error_data_path_checker = FileChecker(msCheckerConfig.error_data_path, FileCheckConst.DIR, + error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) error_data_path = error_data_path_checker.common_check() - initialize_save_path(error_data_path, UT_ERROR_DATA_DIR) + error_data_path =initialize_save_path(error_data_path, UT_ERROR_DATA_DIR) + return error_data_path def get_validated_result_csv_path(result_csv_path, mode): @@ -378,12 +387,10 @@ def _run_ut_parser(parser): help=" The path of accuracy_checking_result_{timestamp}.csv, " "when run ut is interrupted, enter the file path to continue run ut.", required=False) - parser.add_argument("-real_data_path", dest="real_data_path", nargs="?", const="", default="", type=str, - help=" In real data mode, the root directory for storing real data " - "must be configured.", - required=False) parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true", help=" Whether to filter the api in the api_info_file.", required=False) + parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str, + help=" The path of config.json", required=False) def preprocess_forward_content(forward_content): @@ -397,9 +404,9 @@ def preprocess_forward_content(forward_content): if key not in arg_cache: filtered_new_args = [ {k: v for k, v in arg.items() if k not in ['Max', 'Min']} - for arg in value['args'] if isinstance(arg, dict) + for arg in value['input_args'] if isinstance(arg, dict) ] - arg_cache[key] = (filtered_new_args, value['kwargs']) + arg_cache[key] = (filtered_new_args, value['input_kwargs']) filtered_new_args, new_kwargs = arg_cache[key] @@ -464,14 +471,22 @@ def run_ut_command(args): if args.result_csv_path: result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result') details_csv_path = get_validated_details_csv_path(result_csv_path) + white_list = msCheckerConfig.white_list + black_list = msCheckerConfig.black_list + error_data_path = msCheckerConfig.error_data_path + if args.config_path: + _, task_config = parse_json_config(args.config_path, Const.RUN_UT) + white_list = task_config.white_list + black_list = task_config.black_list + error_data_path = task_config.error_data_path if save_error_data: if args.result_csv_path: time_info = result_csv_path.split('.')[0].split('_')[-1] global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info - initialize_save_error_data() + error_data_path = initialize_save_error_data(error_data_path) run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data, - args.result_csv_path, real_data_path) + args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path) run_ut(run_ut_config) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json similarity index 100% rename from debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb06867371c6234583cabd485bcaa3dd671cb00c --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py @@ -0,0 +1,15 @@ +import os +from pkgutil import iter_modules +from importlib import import_module + +""" +gpu and cpu not implement benchmark function, supplementary benchmarking function implementation +""" + +package_path = os.path.dirname(os.path.realpath(__file__)) +for _, module_name, _ in iter_modules([package_path]): + module = import_module(f"{__name__}.{module_name}") + for attr_name in dir(module): + attr = getattr(module, attr_name) + if callable(attr) and "npu_custom" not in attr_name: + globals()[attr_name] = attr diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py new file mode 100644 index 0000000000000000000000000000000000000000..caf21a604c68a61343c6dbafa2e8c604a2d9649f --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py @@ -0,0 +1,28 @@ +import torch + + +def npu_apply_adam_w(beta1_power, beta2_power, lr, weight_decay, + beta1, beta2, eps, grad, max_grad_norm, amsgrad, maximize, out): + var, m, v = out + if amsgrad: + max_grad_norm = (torch.rand(var.shape) * 10.0 - 5.0).to(var.dtype) + beta1_power_out = beta1_power * beta1 + beta2_power_out = beta2_power * beta2 + var_t = var * (1 + (-lr * weight_decay)) + gt = -grad if maximize else grad + m_out = m * beta1 - (beta1 + (-1)) * gt + v_out = v * beta2 - (beta2 + (-1)) * gt * gt + + if amsgrad: + max_grad_norm_out = torch.max(max_grad_norm, v_out) + if (1 - beta2_power_out) == 0: + beta2_power_out -= eps + denom = torch.sqrt(torch.div(max_grad_norm_out, (1 - beta2_power_out))) + eps + else: + vraintain = torch.div(v_out, (1 - beta2_power_out)) + denom = torch.sqrt(vraintain) + eps + + if (1 - beta1_power_out) == 0: + beta1_power_out -= eps + var_out = var_t + torch.div(-lr * m_out, (1 - beta1_power_out)).div(denom) + return var_out.cpu(), m_out.cpu(), v_out.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..627bf11b64f3e62b6b305b7f7559d0fa69056ba6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py @@ -0,0 +1,19 @@ +def npu_confusion_transpose(data, perm, shape, transpose_first): + if transpose_first: + output = data.permute(*perm).contiguous().view(shape) + else: + output = data.view(shape).permute(*perm) + return output.cpu() + + +def npu_confusion_transpose_backward(grad, perm, shape, transpose_first): + shape_cal = shape if transpose_first else [shape[perm_dim] for perm_dim in perm] + perm_cal = [0] * len(perm) + for i, perm_dim in enumerate(perm): + perm_cal[perm_dim] = i + + if transpose_first: + result = grad.permute(*perm_cal).reshape(shape_cal) + else: + result = grad.reshape(shape_cal).permute(*perm_cal) + return result.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a9ca080851b580e8d1554a75b2c6ee1aee5165 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py @@ -0,0 +1,55 @@ +import torch + + +def fast_gelu(input0): + attr = 1.702 + const_0 = 0 - attr + const_1 = 1 + const_2 = attr / 2 + + abs_x = torch.abs(input0) + mul_abs_x = abs_x * const_0 + exp_abs_x = torch.exp(mul_abs_x) + div_down = exp_abs_x + const_1 + + pn_x = input0 - abs_x + mul_pn_x = pn_x * const_2 + exp_pn_x = torch.exp(mul_pn_x) + div_up = input0 * exp_pn_x + div_down_rec = torch.reciprocal(div_down) + result = div_up * div_down_rec + + return result.cpu() + + +def npu_fast_gelu_backward(grad, input_x): + const_2 = 1.702 + const_3 = 1.0 + const_1 = 0.0 - const_2 + + # e^(-1.702x) + abs_x = torch.abs(input_x) + mul_abs_x = abs_x * const_1 + exp_x = torch.exp(mul_abs_x) + + # 1.702xe^(-1.702x) + add_2 = input_x * exp_x + add_2 = add_2 * const_2 + + # e^(1.702(x-|x|)) + pn_x = input_x - abs_x + mul_pn_x = pn_x * const_2 + exp_pn_x = torch.exp(mul_pn_x) + + # e^(-1.702x) + 1.702xe^(-1.702x) + e^(1.702(x-|x|)) + div_up = exp_x + add_2 + div_up = div_up + exp_pn_x + + # (e^(-1.702x)+1)^2 + div_down_i = exp_x + const_3 + div_down = div_down_i * div_down_i + div_down_rec = torch.reciprocal(div_down) + result_temp = div_up * div_down_rec + result = grad * result_temp + + return result.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..f6949c079e2b7928f2f0b30d5fbdb305dc0ba535 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py @@ -0,0 +1,6 @@ +import torch + + +def npu_layer_norm_eval(data, normalized_shape): + result = torch.nn.functional.layer_norm(data, normalized_shape) + return result.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..95db875edf6dea3103fcd6b0d8533a850d8edad8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py @@ -0,0 +1,12 @@ +import torch + + +def npu_linear(x, weight, bias): + output = torch.nn.functional.linear(x, weight, bias) + return output.cpu() + + +def npu_linear_backward(grad, input_data, weight): + input_grad = torch.matmul(grad, weight) + weight_grad = torch.matmul(grad.t(), input_data) + return input_grad.cpu(), weight_grad.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..ed1c746ec1618bbbec7cb49914a7e71012047d6b --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py @@ -0,0 +1,48 @@ +import torch + + +def matmul_backward(grad, self, other, mask): + grad_self, grad_other = None, None + dim_self = self.dim() + dim_other = other.dim() + + size_grad = list(grad.size()) + size_self = list(self.size()) + size_other = list(other.size()) + if dim_self == 1 and dim_other == 1: + grad_self = other.mul(grad) if mask[0] else grad_self + grad_other = self.mul(grad) if mask[1] else grad_other + elif dim_self == 2 and dim_other == 1: + grad_self = grad.unsqueeze(1).mm(other.unsqueeze(0)) if mask[0] else grad_self + grad_other = self.transpose(-1, -2).mm(grad.unsqueeze(1)).squeeze_(1) if mask[1] else grad_other + elif dim_self == 1 and dim_other == 2: + grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0) if mask[0] else grad_self + grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0)) if mask[1] else grad_other + elif dim_self >= 3 and (dim_other == 1 or dim_other == 2): + view_size = 1 if dim_other == 1 else size_grad[-1] + unfolded_grad = (grad.unsqueeze(-1) if dim_other == 1 else grad).contiguous().view(-1, view_size) + if mask[0]: + grad_self = unfolded_grad.mm(other.unsqueeze(0) if dim_other == 1 else other.transpose(-1, -2)) \ + .view(size_self) + if mask[1]: + unfolded_self = self.contiguous().view([-1, size_self[-1]]) + grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other) + elif (dim_self == 1 or dim_self == 2) and dim_other >= 3: + view_size = 1 if dim_self == 1 else size_grad[-2] + unfolded_grad_T = grad.view([-1, view_size]) \ + if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size]) + if mask[0]: + # create a 2D-matrix from other + unfolded_other_T = \ + other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2) + grad_self = unfolded_other_T.mm(unfolded_grad_T).transpose(-1, -2).view(size_self) + if mask[1]: + size_other_T = size_other[:-2] + size_other_T.extend(size_other[::-1][:2]) + grad_other = \ + unfolded_grad_T.mm(self.unsqueeze(0) if dim_self == 1 else self).view(size_other_T).transpose(-1, -2) + else: + grad_self = torch.matmul(grad, other.transpose(-1, -2)) if mask[0] else grad_self + grad_other = torch.matmul(self.transpose(-1, -2), grad) if mask[1] else grad_other + + return grad_self.cpu(), grad_other.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..63f1fa2a3b657bce4fc470edf75d8271c6666364 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py @@ -0,0 +1,421 @@ +import torch +import numpy as np +from einops import rearrange + +from msprobe.pytorch.common.utils import logger + +gtype = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86 +softmax_build_mode = "QKV" # "MAX_SUM" + +""" +# 前向函数声明对比 +标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob +融合算子:npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None, + atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, + next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0, + gen_mask_parallel=True, sync=False + +# 反向函数声明对比 +标杆实现:fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob +融合算子:npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None, + atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None, + attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647, + next_tockens=2147483647, inner_precise=0, seed=0, offset=0, + numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False +""" + + +def softmax_forward(x): + x_max = torch.max(x, dim=-1, keepdims=True)[0] + x_sub = x.sub(x_max) + y = torch.exp(x_sub) + x_sum = y.sum(dim=-1, keepdims=True) + res = y.div(x_sum) + return res, x_max, x_sum + + +def softmax_grad(dp, softmax_res): + muls = dp * softmax_res + muls_r = muls.sum(dim=-1, keepdims=True) + sub_r = dp - muls_r + res = sub_r * softmax_res + return res + + +def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype): + if num_kv_heads == 0 or num_kv_heads < num_heads: + raise ValueError(f"num_kv_heads must be non-zero and less than num_heads.") + + factor = num_heads // num_kv_heads + kv_shape = kv_tensor.shape + B = kv_shape[0] + S = kv_shape[2] + D = kv_shape[3] + kv_res = torch.zeros([B, num_heads, S, D]).to(dtype) + for i in range(num_heads): + j = i // factor + kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :] + return kv_res + + +def calculate_qk(q, k, atten_mask, pse, scale): + if pse is None or len(pse.shape) == 0: + qk = torch.matmul(q, k.permute(0, 1, 3, 2)).mul(scale) + else: + qk = (torch.matmul(q, k.permute(0, 1, 3, 2)) + pse).mul(scale) + if atten_mask is None or len(atten_mask.shape) == 0: + return qk + else: + qk = qk + atten_mask.bool() * (-40000.0) # -10000 + return qk + + +def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob): + qk = calculate_qk(q, k, atten_mask, pse, scale) + softmax_res, softmax_max, softmax_sum = softmax_forward(qk) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res + else: + drop_res = softmax_res * drop_mask * (1.0 / keep_prob) + y = torch.matmul(drop_res, v) + return y, softmax_max, softmax_sum + + +def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob): + dp = torch.matmul(dx, v.permute(0, 1, 3, 2)) + if drop_mask is None or len(drop_mask.shape) == 0: + drop_res = softmax_res.permute(0, 1, 3, 2) + dp_drop = dp + else: + drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2) + dp_drop = dp * drop_mask * (1.0 / keep_prob) + dv = torch.matmul(drop_res, dx) + softmax_grad_res = (softmax_grad(dp_drop, softmax_res) * scale) + dq = torch.matmul(softmax_grad_res, k) + dk = torch.matmul(softmax_grad_res.permute(0, 1, 3, 2), q) + return dq, dk, dv + + +def parse_bsnd_args(query, key, head_num, input_layout): + supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"] + B, S1, S2, N1, N2, D, H1, H2 = None, None, None, head_num, None, None, None, None + + if not isinstance(input_layout, str) or input_layout not in supported_input_layout: + raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.") + + if input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + try: + if input_layout == "BSH": + B, S1, H1 = query.shape + _, S2, H2 = key.shape + D = H1 // N1 + N2 = H2 // D + elif input_layout == "SBH": + S1, B, H1 = query.shape + S2, _, H2 = key.shape + D = H1 // N1 + N2 = H2 // D + elif input_layout == "BSND": + B, S1, N1, D = query.shape + _, S2, N2, _ = key.shape + H1 = N1 * D + H2 = N2 * D + elif input_layout == "BNSD": + B, N1, S1, D = query.shape + _, N2, S2, _ = key.shape + H1 = N1 * D + H2 = N2 * D + except Exception as e: + raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e + + if D == 0: + raise ValueError(f"Value D must be non-zero.") + DTYPE = query.dtype + return B, S1, S2, N1, N2, D, H1, H2, DTYPE + + +def convert_from_bnsd(_input, input_layout): + if input_layout == "BSH": + # (B,N,S,D)=>(B,S,N*D) + out = rearrange(_input, 'b n s d -> b s (n d)').contiguous() + elif input_layout == "SBH": + # (B,N,S,D)=>(S,B,N*D) + out = rearrange(_input, 'b n s d -> s b (n d)').contiguous() + elif input_layout == "BSND": + # (B,N,S,D)=>(B,S,N,D) + out = rearrange(_input, 'b n s d -> b s n d').contiguous() + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + return out + + +def convert_to_bnsd(_input, n, input_layout): + # 默认"BNSD"无需处理 + if input_layout == "BSH": + # (B,S,N*D)=>(B,N,S,D) + out = rearrange(_input, 'b s (n d) -> b n s d', n=n) + elif input_layout == "SBH": + # (S,B,N*D)=>(B,N,S,D) + out = rearrange(_input, 's b (n d) -> b n s d', n=n) + elif input_layout == "BSND": + # (B,S,N,D)=>(B,N,S,D) + out = rearrange(_input, 'b s n d -> b n s d', n=n) + elif input_layout == "TND": + raise ValueError(f"input_layout {input_layout} does not supported for now.") + else: + out = _input + if out.dim() != 4: + raise ValueError(f"convert qkv format failed with input_layout {input_layout}.") + return out.to(gtype) + + +def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype): + """ + # 当sparse_mode=2、3、4时小算子到融合算子会走这个优化,反过来看就要拆解回原来的基本实现 + ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype) + """ + shape = [S1, S2] + + if atten_mask is not None: + # 当FA的输入已经包含atten_mask时,可以认为已经是转换之后的mask矩阵了,有三种特殊场景,即稀疏矩阵场景,需要进行逆向还原 + if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4: + logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}") + + if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048: + if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)): + if sparse_mode == 2: + atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1)) + elif sparse_mode == 3: + atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1)) + elif sparse_mode == 4: + atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + atten_mask = atten_mask_u + atten_mask_l + logger.debug(f"反向转换atten_mask {atten_mask.shape}") + return atten_mask.to(dtype) + + return atten_mask.to(dtype) + + if atten_mask is not None: + if atten_mask.dim() == 2: + if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2: + raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}") + shape = [S1, S2] + elif atten_mask.dim() == 4: + if atten_mask.shape[1] == 1: + shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2] + else: + shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2] + + if sparse_mode == 0: + atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + atten_mask = atten_mask_u + atten_mask_l + elif sparse_mode == 1: # no sparse + atten_mask = torch.from_numpy(np.zeros(shape)) + elif sparse_mode == 2: + atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1)) + elif sparse_mode == 3: + atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1)) + elif sparse_mode == 4: + atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1)) + atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1)) + atten_mask = atten_mask_u + atten_mask_l + # 注:不会出现sparse_mode=5的情况,该情况要求必须要传入atten_mask,且atten_mask矩阵数据格式须为BNSS或B1SS, + # 因此可以认为FA的输入已经是正确的atten_mask了 + return atten_mask.to(dtype) + + +def generate_kv(key, value, N1, N2): + # N不等长适配by cdy + if not (N1 == N2): + k_new = broadcast_kv(N1, N2, key, key.dtype) + v_new = broadcast_kv(N1, N2, value, value.dtype) + else: + k_new = key + v_new = value + return k_new, v_new + + +def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale): + """ + attention = softmax(QK^T/sqrt(d))V + softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max)) + """ + logger.info("Using QKV to rebuild original softmax") + qk = calculate_qk(q, k, atten_mask, pse, scale) + softmax_res, x_max, x_sum = softmax_forward(qk) + return softmax_res + + +def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum): + """ + attention = softmax(QK^T/sqrt(d))V + softmax(x_i) = e^(x_i - x_max_i) / x_sum_i) + """ + logger.info("Using softmax_max and softmax_sum to rebuild original softmax") + qk = calculate_qk(q, k, atten_mask, pse, scale) + if softmax_max.shape[-1] == 0: + raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}") + repeat_dim = qk.shape[-1] // softmax_max.shape[-1] + softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div( + softmax_sum.repeat(1, 1, 1, repeat_dim)) + return softmax_res + + +def npu_fusion_attention_forward_patch(*args, **kwargs): + # query, key, value, head_num, input_layout + if len(args) != 5: + raise ValueError(f"Unsupported npu_fusion_attention args {args}.") + + B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[3], args[4]) + if N1 == N2 and S1 == S2: + logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}") + else: + logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}") + if not (N1 % N2 == 0 and N1 >= N2): + raise ValueError(f"N1与N2不匹配,请检查: N1 = {N1}, N2 = {N2}.") + + dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2, + "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE} + + new_kwargs = {"keep_prob": 1, + "scale": kwargs.get("scale", 1 / (D ** 0.5)), + "sparse_mode": kwargs.get("sparse_mode", 0), + "prefix": kwargs.get("prefix"), + "pre_tockens": kwargs.get("pre_tockens", 2147483647), + "next_tockens": kwargs.get("next_tockens", 2147483647), + "pse": kwargs.get("pse"), + "padding_mask": kwargs.get("padding_mask"), + "atten_mask": kwargs.get("atten_mask")} + + return args, dims_kwargs, new_kwargs + + +def npu_fusion_attention_backward_patch(*args, **kwargs): + if len(args) != 6: + raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.") + + B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5]) + if N1 == N2 and S1 == S2: + logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}") + else: + logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}") + if not (N1 % N2 == 0 and N1 >= N2): + raise ValueError(f"N1与N2不匹配,请检查: N1 = {N1}, N2 = {N2}.") + + dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2, + "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE} + + new_kwargs = {"keep_prob": 1, + "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)), + "sparse_mode": kwargs.get("sparse_mode", 0), + "prefix": kwargs.get("prefix"), + "pre_tockens": kwargs.get("pre_tockens", 2147483647), + "next_tockens": kwargs.get("next_tockens", 2147483647), + "pse": kwargs.get("pse"), + "padding_mask": kwargs.get("padding_mask"), + "softmax_max": kwargs.get("softmax_max"), + "softmax_sum": kwargs.get("softmax_sum"), + "softmax_in": kwargs.get("softmax_in"), + "attention_in": kwargs.get("attention_in"), + "seed": kwargs.get("seed", 0), + "offset": kwargs.get("offset", 0), + "numels": kwargs.get("numels", 0), + "atten_mask": kwargs.get("atten_mask")} + + return args, dims_kwargs, new_kwargs + + +def npu_fusion_attention(*args, **kwargs): + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs) + query, key, value, input_layout = new_args[0], new_args[1], new_args[2], new_args[4] + N1 = dims_kwargs.get("N1") + N2 = dims_kwargs.get("N2") + S1 = dims_kwargs.get("S1") + S2 = dims_kwargs.get("S2") + B = dims_kwargs.get("B") + DTYPE = dims_kwargs.get("DTYPE") + atten_mask = new_kwargs.get("atten_mask") + keep_prob = new_kwargs.get("keep_prob") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tockens") + pse = new_kwargs.get("pse") + scale = new_kwargs.get("scale") + + atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE) + query = convert_to_bnsd(query, N1, input_layout) + key = convert_to_bnsd(key, N2, input_layout) + value = convert_to_bnsd(value, N2, input_layout) + k_new, v_new = generate_kv(key, value, N1, N2) + out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new, + drop_mask=None, atten_mask=atten_mask, + pse=pse, scale=scale, + keep_prob=keep_prob) + if out_golden.dim() == 5: + out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3), + out_golden.size(4)) + out_golden = convert_from_bnsd(out_golden, input_layout) + + return out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu() + + +def npu_fusion_attention_grad(*args, **kwargs): + # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob + new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs) + query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5] + N1 = dims_kwargs.get("N1") + N2 = dims_kwargs.get("N2") + S1 = dims_kwargs.get("S1") + S2 = dims_kwargs.get("S2") + B = dims_kwargs.get("B") + D = dims_kwargs.get("D") + DTYPE = dims_kwargs.get("DTYPE") + atten_mask = new_kwargs.get("atten_mask") + keep_prob = new_kwargs.get("keep_prob") + sparse_mode = new_kwargs.get("sparse_mode") + pre_tockens = new_kwargs.get("pre_tockens") + next_tockens = new_kwargs.get("next_tockens") + pse = new_kwargs.get("pse") + softmax_max = new_kwargs.get("softmax_max") + softmax_sum = new_kwargs.get("softmax_sum") + scale_value = new_kwargs.get("scale_value") + + atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE) + query = convert_to_bnsd(query, N1, input_layout) + dx = convert_to_bnsd(dx, N1, input_layout) + key = convert_to_bnsd(key, N2, input_layout) + value = convert_to_bnsd(value, N2, input_layout) + k_new, v_new = generate_kv(key, value, N1, N2) + + if softmax_build_mode == "QKV": + softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value) + else: + softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum) + + dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob) + + # N不等长适配by cdy + if not (N1 == N2): + if N2 == 0: + raise ValueError("dims_kwargs.N2 must be non-zero.") + G = int(N1 / N2) + dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D) + dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D) + + if dq.dim() == 5: + dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4)) + if dk.dim() == 5: + dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4)) + if dv.dim() == 5: + dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4)) + + dq = convert_from_bnsd(dq, input_layout) + dk = convert_from_bnsd(dk, input_layout) + dv = convert_from_bnsd(dv, input_layout) + + return dq.cpu(), dk.cpu(), dv.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..e647312fdb221137e58f775d21e0605d21f98e07 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py @@ -0,0 +1,15 @@ +import torch + + +def npu_rms_norm(x, gamma, epsilon=1e-5): + rstd = torch.rsqrt(torch.mean(torch.pow(x, 2), axis=-1, keepdim=True) + epsilon) + res = x * rstd * gamma + return res.cpu(), rstd.float().cpu() + + +def npu_rms_norm_backward(grad, x, gamma, rstd): + mean_gy = (grad * x * gamma * rstd).mean(dim=-1, keepdim=True) + grad_x = (grad * gamma - x * rstd * mean_gy) * rstd + grad_gamma = x * grad * rstd + return grad_x.cpu(), grad_gamma.cpu() + diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py new file mode 100644 index 0000000000000000000000000000000000000000..0e0fda5f73fced0913ff1fd755a68e0dfb6652fd --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py @@ -0,0 +1,52 @@ +import torch + + +def npu_rotary_mul(x, r1, r2): + x1, x2 = torch.chunk(x, 2, -1) + x_new = torch.cat((-x2, x1), dim=-1) + output = r1 * x + r2 * x_new + return output.cpu() + + +def npu_rotary_mul_backward(dy_tensor, x, r1, r2): + x.requires_grad = True + r1.requires_grad = True + r2.requires_grad = True + # golden + x1, x2 = torch.chunk(x, 2, -1) + x_new = torch.cat((-x2, x1), dim=-1) + golden_tensor = r1 * x + r2 * x_new + golden_tensor.backward(dy_tensor) + r1_shape = r1.shape + r1_grad = torch.zeros(r1_shape).type(torch.float32) + r2_grad = torch.zeros(r1_shape).type(torch.float32) + x1, x2 = torch.chunk(x.float(), 2, -1) + x_new2 = torch.cat((-x2, x1), dim=-1) + x_shape = x.shape + h = x.float() + grad = dy_tensor.float() + condition_1 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and + ((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and + (r1_shape[1] == x_shape[1]) and (r1_shape[3] == x_shape[3])) + condition_2 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and + ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and + (r1_shape[2] == x_shape[2]) and (r1_shape[3] == x_shape[3])) + condition_3 = (((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and + ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and + (r1_shape[0] == x_shape[0]) and (r1_shape[3] == x_shape[3])) + if condition_1: + for i in range(x_shape[0]): + for j in range(x_shape[2]): + r2_grad[0, :, 0, :] += (x_new2[i, :, j, :] * grad[i, :, j, :]) + r1_grad[0, :, 0, :] += (h[i, :, j, :] * grad[i, :, j, :]) + elif condition_2: + for i in range(x_shape[0]): + for j in range(x_shape[1]): + r2_grad[0, 0, :, :] += (x_new2[i, j, :, :] * grad[i, j, :, :]) + r1_grad[0, 0, :, :] += (h[i, j, :, :] * grad[i, j, :, :]) + elif condition_3: + for i in range(x_shape[1]): + for j in range(x_shape[2]): + r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :]) + r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :]) + return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..8717aebaf902db8af37fe5d8ec5cb054fd83439b --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py @@ -0,0 +1,26 @@ +import torch + + +def npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask): + if fixed_triu_mask: + mask = (torch.triu(torch.ones(mask.shape), k=1)).bool().to(mask.device) + dtype = x.dtype + x = (x * scale).masked_fill(mask, value=-10000) + x = x - torch.max(x, dim=-1, keepdims=True)[0] + x = torch.exp(x.float()) + y = torch.div(x, torch.sum(x, dim=-1, keepdims=True)) + return y.to(dtype).cpu() + + +def npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask): + if fixed_triu_mask: + mask = (torch.triu(torch.ones(mask.shape), k=1)).bool().to(mask.device) + dtype = y_grad.dtype + y_grad = y_grad.float() + y = y.float() + x_grad = y_grad * y + x_grad = y_grad - torch.sum(x_grad, dim=-1, keepdims=True) + x_grad = x_grad * y + x_grad = x_grad * scale + x_grad = x_grad.masked_fill(mask, value=0) + return x_grad.to(dtype).cpu() diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..e03c975a50aaf37573645365d2e6228706070844 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py @@ -0,0 +1,55 @@ +import torch + + +def npu_swiglu(x, dim=-1): + tensor_dtype = x.dtype + + inTensors = torch.chunk(x, 2, dim=dim) + if tensor_dtype == torch.float32: + tensor_scalar = torch.sigmoid(torch.mul(inTensors[0], 1.0)) + output_data = torch.mul(torch.mul(tensor_scalar, inTensors[0]), inTensors[1]) + else: + tensor_self_float = inTensors[0].type(torch.float) + tensor_other_float = inTensors[1].type(torch.float) + tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type( + torch.float32) * tensor_other_float + output_data = tensor_out_float.type(tensor_dtype) + return output_data.cpu() + + +def npu_swiglu_backward(grad, x, dim=-1): + tensor_dtype = grad.dtype + in_tensors = torch.chunk(x, 2, dim=dim) + tensor_grad_out = grad + + if tensor_dtype == torch.float16: + tensor_out1 = torch.mul( + torch.mul(in_tensors[1].type(torch.float32), swish_grad(1, in_tensors[0].type(torch.float32))), + tensor_grad_out.type(torch.float32)).type(torch.float16) + tensor_out2 = torch.mul(tensor_grad_out.type(torch.float32), + swish(1, in_tensors[0].type(torch.float32))).type(torch.float16) + output = torch.cat((tensor_out1, tensor_out2), dim) + elif tensor_dtype == torch.bfloat16: + tensor_self_float = in_tensors[0].type(torch.float) + tensor_other_float = in_tensors[1].type(torch.float) + tensor_gradout_float = tensor_grad_out.type(torch.float) + + tensor_out1 = torch.mul(tensor_gradout_float, swish_grad(1.0, tensor_self_float)).type(torch.bfloat16).type( + torch.float32) * tensor_other_float + tensor_out2 = swish(1.0, tensor_self_float).type(torch.bfloat16).type(torch.float32) * tensor_gradout_float + tensor_out_float = torch.cat((tensor_out1, tensor_out2), dim=dim) + output = tensor_out_float.type(torch.bfloat16) + else: + tensor_out1 = torch.mul(torch.mul(in_tensors[1], swish_grad(1.0, in_tensors[0])), tensor_grad_out) + tensor_out2 = torch.mul(tensor_grad_out, swish(1.0, in_tensors[0])) + output = torch.cat((tensor_out1, tensor_out2), dim) + return output.cpu() + + +def swish_grad(beta, x): + return torch.sigmoid(beta * x) + x * (1 - torch.sigmoid(beta * x)) * torch.sigmoid(beta * x) * beta + + +def swish(beta, x): + return x * torch.sigmoid(beta * x) + diff --git a/debug/accuracy_tools/atat/pytorch/common/__init__.py b/debug/accuracy_tools/msprobe/pytorch/common/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/common/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/common/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/common/compare_script.template b/debug/accuracy_tools/msprobe/pytorch/common/compare_script.template similarity index 100% rename from debug/accuracy_tools/atat/pytorch/common/compare_script.template rename to debug/accuracy_tools/msprobe/pytorch/common/compare_script.template diff --git a/debug/accuracy_tools/atat/pytorch/common/log.py b/debug/accuracy_tools/msprobe/pytorch/common/log.py similarity index 81% rename from debug/accuracy_tools/atat/pytorch/common/log.py rename to debug/accuracy_tools/msprobe/pytorch/common/log.py index e496e9b72ad449c24dd3f2a76a9a149d0f2eff1e..cea518fa47b977ad08b5fded3401a63bf3a29d03 100644 --- a/debug/accuracy_tools/atat/pytorch/common/log.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/log.py @@ -1,9 +1,9 @@ import os import time import sys -from atat.pytorch.common.utils import get_rank_if_initialized -from atat.core.common.log import BaseLogger -from atat.core.common.exceptions import DistributedNotInitializedError +from msprobe.pytorch.common.utils import get_rank_if_initialized +from msprobe.core.common.log import BaseLogger +from msprobe.core.common.exceptions import DistributedNotInitializedError class PyTorchLogger(BaseLogger): diff --git a/debug/accuracy_tools/atat/pytorch/common/parse_json.py b/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py similarity index 89% rename from debug/accuracy_tools/atat/pytorch/common/parse_json.py rename to debug/accuracy_tools/msprobe/pytorch/common/parse_json.py index a938f5f0da9ea465923157ad4131ce72bb92962f..89edd834cf67c8006f577f1d33fee32b3b6dc751 100644 --- a/debug/accuracy_tools/atat/pytorch/common/parse_json.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py @@ -1,5 +1,7 @@ import json -from atat.core.common.exceptions import ParseJsonException + +from msprobe.core.common.exceptions import ParseJsonException +from msprobe.core.common.file_utils import FileOpen def parse_json_info_forward_backward(json_path): @@ -11,7 +13,7 @@ def parse_json_info_forward_backward(json_path): api_name = '.'.join(name_struct[:-1]) return api_name - with open(json_path, 'r') as f: + with FileOpen(json_path, 'r') as f: dump_json = json.load(f) real_data_path = dump_json.get("dump_data_dir") diff --git a/debug/accuracy_tools/atat/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py similarity index 87% rename from debug/accuracy_tools/atat/pytorch/common/utils.py rename to debug/accuracy_tools/msprobe/pytorch/common/utils.py index 4b413ac57502ad777bf68140fb57854a691811b5..181491488f91049148c51827387edc53d21d41cf 100644 --- a/debug/accuracy_tools/atat/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -14,13 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +import logging import os import random import stat import torch +import torch.distributed as dist import numpy as np from functools import wraps -from atat.core.common.exceptions import DistributedNotInitializedError +from msprobe.core.common.exceptions import DistributedNotInitializedError try: import torch_npu @@ -29,7 +31,6 @@ except ImportError: else: is_gpu = False - torch_without_guard_version_list = ['2.1', '2.2'] for version in torch_without_guard_version_list: if torch.__version__.startswith(version): @@ -222,3 +223,36 @@ class Const: CONVERT_API = { "int32_to_int64": ["cross_entropy"] } + + +def get_tensor_rank(in_feat, out_feat): + if dist.is_initialized(): + return dist.get_rank() + + def get_tensor_rank_single(x): + if isinstance(x, (list, tuple)): + if len(x) > 0: + return get_tensor_rank_single(x[0]) + elif isinstance(x, torch.Tensor): + device = x.device + if device.type != 'cpu': + return device.index + return None + + in_rank = get_tensor_rank_single(in_feat) + out_rank = get_tensor_rank_single(out_feat) + tensor_rank = in_rank if in_rank else out_rank + return tensor_rank + + +def _create_logger(level=logging.INFO): + logger_ = logging.getLogger() + logger_.setLevel(level) + ch = logging.StreamHandler() + ch.setLevel(level) + logger_.addHandler(ch) + return logger_ + + +log_level = logging.DEBUG if os.environ.get("API_ACCURACY_CHECK_LOG_LEVEL") == "1" else logging.INFO +logger = _create_logger(log_level) diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py similarity index 94% rename from debug/accuracy_tools/atat/pytorch/compare/acc_compare.py rename to debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py index 061c9cdfca872bb76ad8ffa0e292fb24ce1a4ada..16533accb26680e341a71e8f40f30e2ed899301f 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py @@ -27,15 +27,17 @@ from openpyxl.styles import PatternFill from collections import namedtuple from dataclasses import dataclass -from atat.pytorch.compare.match import graph_mapping -from atat.pytorch.compare.highlight import HighlightRules, get_header_index -from atat.pytorch.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, get_error_message -from atat.pytorch.advisor.advisor import Advisor -from atat.pytorch.common.log import logger -from atat.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \ +from msprobe.pytorch.compare.match import graph_mapping +from msprobe.pytorch.compare.highlight import HighlightRules, get_header_index +from msprobe.pytorch.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \ + get_error_message +from msprobe.pytorch.advisor.advisor import Advisor +from msprobe.pytorch.common.log import logger +from msprobe.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \ format_value, check_file_not_exists, check_configuration_param, task_dumppath_get -from atat.core.common.file_check import FileChecker, change_mode, FileOpen, create_directory -from atat.core.common.const import Const, CompareConst, FileCheckConst +from msprobe.core.common.file_utils import FileChecker, change_mode, FileOpen, create_directory +from msprobe.core.common.const import Const, CompareConst, FileCheckConst +from msprobe.core.common.exceptions import FileCheckException def check_graph_mode(a_op_name, b_op_name): @@ -490,6 +492,10 @@ def compare_by_op(op_name, op_name_mapping_dict, input_parma): error_file = error.filename n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE error_flag = True + except FileCheckException: + error_file = data_name + n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE + error_flag = True n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag) if not error_flag: @@ -525,8 +531,10 @@ def handle_inf_nan(n_value, b_value): return n_value, b_value -def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False): +def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False): """找到单个API中需要高亮的行""" + if md5_compare: + return npu_max_index = get_header_index('NPU max', summary_compare) bench_max_index = get_header_index('Bench max', summary_compare) max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) @@ -582,7 +590,7 @@ def get_name_and_state(name): return api_name, state -def find_compare_result_error_rows(result_df, highlight_dict, summary_compare): +def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare): """将dataframe根据API分组,并找到有误差的算子用于高亮""" result = result_df.values start, input_num, output_num, end = 0, 0, 0, len(result_df) @@ -600,7 +608,7 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare): else: output_num = num find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, - summary_compare) + summary_compare, md5_compare) num, last_api_name, last_state = 1, api_name, state start += input_num + output_num input_num, output_num = 1, 0 @@ -611,7 +619,7 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare): input_num = num else: output_num = num - find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare) + find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare, md5_compare) def highlight_rows_xlsx(result_df, highlight_dict, file_path): @@ -637,7 +645,11 @@ def highlight_rows_xlsx(result_df, highlight_dict, file_path): elif (i - 2) in highlight_dict['yellow_rows']: ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid") - wb.save(file_path) + try: + wb.save(file_path) + except Exception as e: + logger.error('Save result file failed') + raise CompareException(CompareException.WRITE_FILE_ERROR) from e change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) @@ -647,8 +659,8 @@ def compare(input_parma, output_path, stack_mode=False, auto_analyze=True, summary_compare, md5_compare = task_dumppath_get(input_parma) check_configuration_param(stack_mode, auto_analyze, fuzzy_match) create_directory(output_path) - check_compare_param(input_parma, output_path, stack_mode, summary_compare, md5_compare) - except CompareException as error: + check_compare_param(input_parma, output_path, summary_compare, md5_compare) + except (CompareException, FileCheckException) as error: logger.error('Compare failed. Please check the arguments and do it again!') sys.exit(error.code) compare_core(input_parma, output_path, stack_mode=stack_mode, @@ -696,7 +708,7 @@ def compare_core(input_parma, output_path, **kwargs): if not md5_compare and not summary_compare: result_df = _do_multi_process(input_parma, result_df) - find_compare_result_error_rows(result_df, highlight_dict, summary_compare) + find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare) highlight_rows_xlsx(result_df, highlight_dict, file_path) if auto_analyze: advisor = Advisor(result_df, output_path) @@ -738,7 +750,7 @@ def parse(pkl_file, module_name_prefix): logger.info(summary_info) -def op_item_parse(item, op_name, index, item_list=[], top_bool=True): +def op_item_parse(item, op_name, index, item_list=None, top_bool=True): if item_list is None: item_list = [] if item is None or (isinstance(item, dict) and not item): @@ -756,9 +768,14 @@ def op_item_parse(item, op_name, index, item_list=[], top_bool=True): else: full_op_name = op_name else: - full_op_name = op_name + '.' + str(index) + full_op_name = op_name + Const.SEP + str(index) if isinstance(item, dict): - if 'dtype' in item: + if 'type' not in item: + for kwarg in item: + kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None) + item_list += kwarg_parsed_list + kwarg_parsed_list.clear() + elif 'dtype' in item: parsed_item = item parsed_item['full_op_name'] = full_op_name item_list.append(parsed_item) @@ -800,8 +817,8 @@ def op_item_parse(item, op_name, index, item_list=[], top_bool=True): else: resolve_api_special_parameters(item, full_op_name, item_list) else: - for j in range(len(item)): - op_item_parse(item[j], full_op_name, j, item_list=item_list, top_bool=False) + for j, item_spec in enumerate(item): + op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False) return item_list @@ -861,13 +878,13 @@ def read_op(op_data, op_name): op_parsed_list += output_parsed_list output_parsed_list.clear() if 'backward' in op_name: - if 'grad_input' in op_data: - input_item = op_data['grad_input'] + if 'input' in op_data: + input_item = op_data['input'] input_parsed_list = op_item_parse(input_item, op_name + '_input', None) op_parsed_list = input_parsed_list.copy() input_parsed_list.clear() - if 'grad_output' in op_data: - output_item = op_data['grad_output'] + if 'output' in op_data: + output_item = op_data['output'] output_parsed_list = op_item_parse(output_item, op_name + '_output', None) op_parsed_list += output_parsed_list output_parsed_list.clear() @@ -952,8 +969,9 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False if npu_ops_queue: for npu_data in npu_ops_queue: get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) + result_to_csv(md5_compare, summary_compare, stack_mode, result) - header = [] +def result_to_csv(md5_compare, summary_compare, stack_mode, result): if md5_compare: header = CompareConst.MD5_COMPARE_RESULT_HEADER[:] elif summary_compare: diff --git a/debug/accuracy_tools/atat/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py similarity index 85% rename from debug/accuracy_tools/atat/pytorch/compare/distributed_compare.py rename to debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py index b89adc1581e8b0cf76ca28c41dbf0e86738ebece..b2b3a6672c49f38ed8c27ef42f13b52bbdb051f5 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/distributed_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py @@ -17,11 +17,12 @@ import os import sys import re -from atat.core.common.utils import CompareException, check_compare_param, \ +from msprobe.core.common.utils import CompareException, check_compare_param, \ check_configuration_param, task_dumppath_get, check_file_or_directory_path, check_regex_prefix_format_valid -from atat.pytorch.compare.acc_compare import compare_core -from atat.core.common.file_check import create_directory -from atat.pytorch.common.log import logger +from msprobe.pytorch.compare.acc_compare import compare_core +from msprobe.core.common.file_utils import create_directory +from msprobe.core.common.exceptions import FileCheckException +from msprobe.pytorch.common.log import logger def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): @@ -86,12 +87,11 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): 'or use compare() api and manually match the ranks.') raise CompareException(CompareException.INVALID_PATH_ERROR) for nr, br in zip(npu_ranks, bench_ranks): - n_dir = os.path.join(npu_dump_dir, nr) - b_dir = os.path.join(bench_dump_dir, br) - s_dir = b_dir - npu_json_path = extract_json(n_dir, stack_json=False) - bench_json_path = extract_json(b_dir, stack_json=False) - stack_json_path = extract_json(s_dir, stack_json=True) + npu_data_dir = os.path.join(npu_dump_dir, nr) + bench_data_dir = os.path.join(bench_dump_dir, br) + npu_json_path = extract_json(npu_data_dir, stack_json=False) + bench_json_path = extract_json(bench_data_dir, stack_json=False) + stack_json_path = extract_json(npu_data_dir, stack_json=True) dump_result_param = { 'npu_json_path': npu_json_path, @@ -103,8 +103,8 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): summary_compare, md5_compare = task_dumppath_get(dump_result_param) check_configuration_param(stack_mode, auto_analyze, fuzzy_match) create_directory(output_path) - check_compare_param(dump_result_param, output_path, stack_mode=stack_mode, summary_compare=summary_compare) - except CompareException as error: + check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare) + except (CompareException, FileCheckException) as error: logger.error('Compare failed. Please check the arguments and do it again!') sys.exit(error.code) compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare, diff --git a/debug/accuracy_tools/atat/pytorch/compare/highlight.py b/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py similarity index 97% rename from debug/accuracy_tools/atat/pytorch/compare/highlight.py rename to debug/accuracy_tools/msprobe/pytorch/compare/highlight.py index 3a6898dedbb6910d0c1c9e55f80b55eb4fa0ed3c..82f0022f8b5d4a0c6472b749d4937bfe39ef8a86 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py @@ -1,8 +1,8 @@ import math import abc import numpy as np -from atat.core.common.utils import get_header_index -from atat.core.common.const import CompareConst +from msprobe.core.common.utils import get_header_index +from msprobe.core.common.const import CompareConst class HighlightCheck(abc.ABC): diff --git a/debug/accuracy_tools/atat/pytorch/compare/mapping.yaml b/debug/accuracy_tools/msprobe/pytorch/compare/mapping.yaml similarity index 100% rename from debug/accuracy_tools/atat/pytorch/compare/mapping.yaml rename to debug/accuracy_tools/msprobe/pytorch/compare/mapping.yaml diff --git a/debug/accuracy_tools/atat/pytorch/compare/match.py b/debug/accuracy_tools/msprobe/pytorch/compare/match.py similarity index 91% rename from debug/accuracy_tools/atat/pytorch/compare/match.py rename to debug/accuracy_tools/msprobe/pytorch/compare/match.py index 148fbb7d640b3fde5e3292508c29404e277cde84..ca335f7d8ec9bda9b594ff9cd2c9d5ed1f7a053f 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/match.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/match.py @@ -1,7 +1,7 @@ import os import yaml -from atat.core.common.file_check import FileOpen -from atat.core.common.utils import CompareException +from msprobe.core.common.file_utils import FileOpen +from msprobe.core.common.utils import CompareException class AtenIrMapping(): diff --git a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/npy_compare.py similarity index 98% rename from debug/accuracy_tools/atat/pytorch/compare/npy_compare.py rename to debug/accuracy_tools/msprobe/pytorch/compare/npy_compare.py index 0cf4c6c00a0a671a6ee46dc6dacce48af6e67adf..5a0feb4cd4a63b6f2ab680c9e9a0f0e92b594e2e 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/npy_compare.py @@ -1,8 +1,8 @@ import abc import numpy as np -from atat.core.common.utils import format_value -from atat.core.common.const import Const, CompareConst -from atat.pytorch.common.log import logger +from msprobe.core.common.utils import format_value +from msprobe.core.common.const import Const, CompareConst +from msprobe.pytorch.common.log import logger def handle_inf_nan(n_value, b_value): diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/__init__.py b/debug/accuracy_tools/msprobe/pytorch/debugger/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/debugger/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py similarity index 89% rename from debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py rename to debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py index 1ad69701e4167db3e2e9f61f7f725f024ec21d16..851a61d04b1442dac30fb93b8b3c888943d10a1f 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py @@ -1,6 +1,6 @@ -from atat.pytorch.common import seed_all -from atat.pytorch.common.log import logger -from atat.core.common.const import Const +from msprobe.pytorch.common import seed_all +from msprobe.pytorch.common.log import logger +from msprobe.core.common.const import Const class DebuggerConfig: @@ -13,6 +13,7 @@ class DebuggerConfig: self.seed = common_config.seed if common_config.seed else 1234 self.is_deterministic = common_config.is_deterministic self.enable_dataloader = common_config.enable_dataloader + self.enable_step_auto_dump = common_config.enable_step_auto_dump self.scope = task_config.scope if task_config.scope else [] self.list = task_config.list if task_config.list else [] self.data_mode = task_config.data_mode if task_config.data_mode else ["all"] @@ -21,7 +22,7 @@ class DebuggerConfig: self.acl_config = common_config.acl_config if common_config.acl_config else "" self.is_forward_acl_dump = True self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS - self.overflow_num = task_config.overflow_num if task_config.overflow_num else 1 + self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1 self.framework = Const.PT_FRAMEWORK if self.task == Const.FREE_BENCHMARK: @@ -46,9 +47,8 @@ class DebuggerConfig: raise ValueError("backward_input must be configured when scope contains 'backward'") if Const.BACKWARD in self.scope[0]: self.is_forward_acl_dump = False - for index in range(len(self.scope)): - # Do this replace operation to let the acl backward dump can be done in forward hook. - self.scope[index] = self.scope[index].replace(Const.BACKWARD, Const.FORWARD) + for index, scope_spec in enumerate(self.scope): + self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD) self.backward_input[self.scope[index]] = self.backward_input_list[index] seed_all(self.seed, self.is_deterministic) diff --git a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py similarity index 49% rename from debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py rename to debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 140d829bedc6fb1243820c0d0ccc84af42f01424..2d566800906cfce2eca2404211b7f6e261f44d5d 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -1,10 +1,10 @@ import torch from torch.utils.data import dataloader -from atat.pytorch.debugger.debugger_config import DebuggerConfig -from atat.pytorch.service import Service -from atat.pytorch.common.log import logger -from atat.pytorch.pt_config import parse_json_config -from atat.core.common.exceptions import MsaccException +from msprobe.pytorch.debugger.debugger_config import DebuggerConfig +from msprobe.pytorch.service import Service +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.pt_config import parse_json_config +from msprobe.core.common.exceptions import MsprobeException class PrecisionDebugger: @@ -25,13 +25,17 @@ class PrecisionDebugger: level=None, model=None, step=None, + enable_step_auto_dump=None ): if not hasattr(self, "initialized"): + self.api_origin = False self.initialized = True self.model = self.check_model_valid(model) common_config, task_config = parse_json_config(config_path, task) if step: common_config.step = step + if enable_step_auto_dump: + common_config.enable_step_auto_dump = enable_step_auto_dump self.config = DebuggerConfig( common_config, task_config, task, dump_path, level ) @@ -41,11 +45,37 @@ class PrecisionDebugger: if self.enable_dataloader: logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.") dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__) + self.enable_step_auto_dump = self.config.enable_step_auto_dump + if self.enable_step_auto_dump: + self.start_for_optimizer() @property def instance(self): return self._instance + @staticmethod + def check_model_valid(model): + if not model or isinstance(model, torch.nn.Module): + return model + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。" + ) + + # 非侵入式dump使能方法 + @classmethod + def start_for_optimizer(cls): + instance = cls._instance + if not instance: + raise Exception("No instance of PrecisionDebugger found.") + elif torch.__version__ < '2.0.0': + raise Exception("Pytorch version is earlier than 2.0.0 does not support optimizer hooks. \ + Please turn off enable_step_auto_dump and use the start, stop and step methods of PrecisionDebugger instead.") + else: + logger.info_on_rank_0("The enable_step_auto_dump is on and start()/stop()/step() will not take effect.") + logger.warning_on_rank_0("Customized optimizer iteration is not supported. Please use start, stop and step methods when using customized optimizer.") + instance.service.hook_optimizer(instance.model) + instance.service.start(instance.model) + @classmethod def start(cls): instance = cls._instance @@ -53,8 +83,25 @@ class PrecisionDebugger: raise Exception("No instance of PrecisionDebugger found.") if instance.enable_dataloader: logger.warning_on_rank_0("DataLoader is enabled, start() skipped.") + elif instance.enable_step_auto_dump: + logger.warning_on_rank_0("optimizer is enabled, start() skipped.") else: - instance.service.start(instance.model) + instance.service.start(instance.model, instance.api_origin) + instance.api_origin = False + + # 指定代码段dump前反向结束符,之后的计算过程数据将被忽略,无法被dump + @classmethod + def forward_backward_dump_end(cls): + instance = cls._instance + if not instance: + raise Exception("PrecisionDebugger instance is not created.") + if instance.enable_dataloader: + logger.warning_on_rank_0("DataLoader is enabled, forward_backward_dump_end() skipped.") + elif instance.enable_step_auto_dump: + logger.warning_on_rank_0("optimizer is enabled, forward_backward_dump_end() skipped.") + else: + instance.service.forward_backward_dump_end() + instance.api_origin = True @classmethod def stop(cls): @@ -63,23 +110,20 @@ class PrecisionDebugger: raise Exception("PrecisionDebugger instance is not created.") if instance.enable_dataloader: logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.") + elif instance.enable_step_auto_dump: + logger.warning_on_rank_0("optimizer is enabled, stop() skipped.") else: instance.service.stop() @classmethod def step(cls): - if not cls._instance: + instance = cls._instance + if not instance: raise Exception("PrecisionDebugger instance is not created.") - cls._instance.service.step() - - @staticmethod - def check_model_valid(model): - if not model or isinstance(model, torch.nn.Module): - return model - raise MsaccException( - MsaccException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。" - ) - + elif instance.enable_step_auto_dump: + logger.warning_on_rank_0("optimizer is enabled, step() skipped.") + else: + instance.service.step() def iter_tracer(func): def func_wrapper(*args, **kwargs): diff --git a/debug/accuracy_tools/atat/pytorch/doc/FAQ.md b/debug/accuracy_tools/msprobe/pytorch/doc/FAQ.md similarity index 72% rename from debug/accuracy_tools/atat/pytorch/doc/FAQ.md rename to debug/accuracy_tools/msprobe/pytorch/doc/FAQ.md index 19a434a194686e1d75c411785e052c7dc0b085f9..8d12a72928ee4d9977b1db05a72f2189b2edd3c1 100644 --- a/debug/accuracy_tools/atat/pytorch/doc/FAQ.md +++ b/debug/accuracy_tools/msprobe/pytorch/doc/FAQ.md @@ -22,15 +22,15 @@ 6. 添加预检工具后截取操作报错:`IndexError: too many indices for tensor of dimension x` 或 `TypeError: len() of a 0-d tensor`。 - 答:注释工具目录mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml文件中Tensor:下的`- __getitem__`,工具会跳过dump该API。如果是需要dump的关键位置API也可以考虑根据报错堆栈信息注释引发报错的类型检查。 + 答:注释工具目录mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml文件中Tensor:下的`- __getitem__`,工具会跳过dump该API。如果是需要dump的关键位置API也可以考虑根据报错堆栈信息注释引发报错的类型检查。 7. 添加预检工具后F.gelu触发ValueError报错:`activation_func must be F.gelu`等。 - 答:注释工具目录mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml文件中functional:下的的`- gelu`,工具会跳过dump该API。如果是需要dump的关键位置API也可以考虑根据报错堆栈信息注释引发报错的类型检查。 + 答:注释工具目录mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml文件中functional:下的的`- gelu`,工具会跳过dump该API。如果是需要dump的关键位置API也可以考虑根据报错堆栈信息注释引发报错的类型检查。 8. 添加预检工具后触发AsStrided算子相关的报错,或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。 - 答:注释工具目录mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml文件中Tensor:下的`- t`和`- transpose`。 + 答:注释工具目录mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml文件中Tensor:下的`- t`和`- transpose`。 9. Tensor 魔法函数具体对应什么操作? @@ -75,7 +75,7 @@ ### dump指定融合算子 -dump指定操作当前支持dump指定融合算子的输入输出,需要在mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml中添加,比如以下代码段调用的softmax融合算子 +dump指定操作当前支持dump指定融合算子的输入输出,需要在mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml中添加,比如以下代码段调用的softmax融合算子 ``` def npu_forward_fused_softmax(self, input_, mask): @@ -111,7 +111,7 @@ torch版本和硬件差异属于正常情况。 **故障现象** -使用atat工具时,报错: error code: EI0006。 +使用msprobe工具时,报错: error code: EI0006。 **故障原因** @@ -136,7 +136,7 @@ torch.npu.set_device('npu:0') torch.npu.set_device(f'npu:{rank}') ``` -如果运行精度比对功能遇到这个报错,尝试安装最新版本的atat。 +如果运行精度比对功能遇到这个报错,尝试安装最新版本的msprobe。 ### 4. dump得到的VF_lstm_99_forward_input.1.0.npy、VF_lstm_99_forward_input.1.1.npy类似的数据是否正常? @@ -147,7 +147,7 @@ torch.npu.set_device(f'npu:{rank}') 在比对脚本中,设置stack_mode=True,例如: ``` -from atat.pytorch import compare +from msprobe.pytorch import compare dump_result_param={ "npu_json_path": "./npu_dump/dump.json", "bench_json_path": "./gpu_dump/dump.json", @@ -174,20 +174,20 @@ compare(dump_result_param, output_path="./output", stack_mode=True) ### 9. dump.json文件中的某些api的dtype类型为float16,但是读取此api的npy文件显示的dtype类型为float32 -- atat工具在dump数据时需要将原始数据从npu to cpu上再转换为numpy类型,npu to cpu的逻辑和gpu to cpu是保持一致的,都存在dtype可能从float16变为float32类型的情况,如果出现dtype不一致的问题,最终dump数据的dtype以pkl文件为准。 +- msprobe工具在dump数据时需要将原始数据从npu to cpu上再转换为numpy类型,npu to cpu的逻辑和gpu to cpu是保持一致的,都存在dtype可能从float16变为float32类型的情况,如果出现dtype不一致的问题,最终dump数据的dtype以pkl文件为准。 -### 10. 使用dataloader后raise异常Exception("atat: exit after iteration {}". format(max(self.config.step)) +### 10. 使用dataloader后raise异常Exception("msprobe: exit after iteration {}". format(max(self.config.step)) - 正常现象,dataloader通过raise结束程序,堆栈信息可忽略。 -### 11. 添加atat工具后截取操作报错:`IndexError: too many indices for tensor of dimension x` 或 `TypeError: len() of a 0-d tensor`。 +### 11. 添加msprobe工具后截取操作报错:`IndexError: too many indices for tensor of dimension x` 或 `TypeError: len() of a 0-d tensor`。 -- 注释工具目录mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml文件中Tensor:下的`- __getitem__`,工具会跳过dump该API。如果是需要dump的关键位置API也可以考虑根据报错堆栈信息注释引发报错的类型检查。 +- 注释工具目录mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml文件中Tensor:下的`- __getitem__`,工具会跳过dump该API。如果是需要dump的关键位置API也可以考虑根据报错堆栈信息注释引发报错的类型检查。 -### 12. 添加atat工具后F.gelu触发ValueError报错:`activation_func must be F.gelu`等。 +### 12. 添加msprobe工具后F.gelu触发ValueError报错:`activation_func must be F.gelu`等。 -- 注释工具目录mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml文件中functional:下的的`- gelu`,工具会跳过dump该API。如果是需要dump的关键位置api也可以考虑根据报错堆栈信息注释引发报错的类型检查。 +- 注释工具目录mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml文件中functional:下的的`- gelu`,工具会跳过dump该API。如果是需要dump的关键位置api也可以考虑根据报错堆栈信息注释引发报错的类型检查。 -### 13. 添加atat工具后触发AsStrided算子相关的报错,或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。 +### 13. 添加msprobe工具后触发AsStrided算子相关的报错,或者编译相关的报错,如:`Failed to compile Op [AsStrided]`。 -- 注释工具目录mstt/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml文件中Tensor:下的`- t`和`- transpose`。 +- 注释工具目录mstt/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml文件中Tensor:下的`- t`和`- transpose`。 diff --git a/debug/accuracy_tools/atat/pytorch/doc/api_accuracy_checker.md b/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md similarity index 84% rename from debug/accuracy_tools/atat/pytorch/doc/api_accuracy_checker.md rename to debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md index 0e45a4e83fb720411b07923e7888cff404148d8d..41b97098ae95394511a375bc9f17a2041de6a826 100644 --- a/debug/accuracy_tools/atat/pytorch/doc/api_accuracy_checker.md +++ b/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md @@ -20,8 +20,8 @@ 精度预检操作流程如下: -1. 在NPU和GPU环境下分别安装atat工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 -2. 在NPU训练脚本内添加atat工具dump接口PrecisionDebugger采集待预检数据。详见《[精度数据采集](./dump.md)》。 +1. 在NPU和GPU环境下分别安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 +2. 在NPU训练脚本内添加msprobe工具dump接口PrecisionDebugger,采集待预检数据。详见《[精度数据采集](./dump.md)》,注意需要配置level="L1"。 3. 将NPU环境下dump的预检数据拷贝至GPU环境。 4. 在NPU和GPU环境下分别执行run_ut,生成结果用于最终api_precision_compare操作的输入。详见“**run_ut预检操作**”。 5. 将NPU和GPU执行run_ut生成的`accuracy_checking_details_{timestamp}.csv`结果文件拷贝至同一环境下。 @@ -43,7 +43,7 @@ run_ut预检操作包括如下场景: 1. 将API信息输入给run_ut模块运行精度检测并比对,运行如下命令: ```bash - atat -f pytorch run_ut -api_info ./dump.json + msprobe -f pytorch run_ut -api_info ./dump.json ``` | 参数名称 | 说明 | 是否必选 | @@ -51,20 +51,22 @@ run_ut预检操作包括如下场景: | -api_info或--api_info_file | 指定API信息文件dump.json。 | 是 | | -save_error_data | 保存精度未达标的API输入输出数据。 | 否 | | -o或--out_path | 指定run_ut执行结果存盘路径,默认“./”(相对于run_ut的路径)。 | 否 | + | | | | | -j或--jit_compile | 开启jit编译。 | 否 | | -d或--device | 指定Device ID,选择UT代码运行所在的卡,默认值为0。 | 否 | | -csv_path或--result_csv_path | 指定本次运行中断时生成的`accuracy_checking_result_{timestamp}.csv`文件路径,执行run_ut中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的`accuracy_checking_result_{timestamp}.csv`文件。详见“**断点续检**”。 | run_ut操作中断后继续执行场景下必选 | | -f或--filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的API。适用于模型较大且重复API较多的场景。 | 否 | + | -config或--config_path | 指定预检操作过程中的额外配置(包括黑名单、白名单等)的[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件,默认未配置。config.json文件的配置可参考《[配置文件说明](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/config/README.md#pytorch场景task配置为run_ut)》。 | 否 | run_ut执行结果包括`accuracy_checking_result_{timestamp}.csv`和`accuracy_checking_details_{timestamp}.csv`两个文件。`accuracy_checking_result_{timestamp}.csv`是API粒度的,标明每个API是否通过测试。建议用户先查看`accuracy_checking_result_{timestamp}.csv`文件,对于其中没有通过测试的或者特定感兴趣的API,根据其API name字段在`accuracy_checking_details_{timestamp}.csv`中查询其各个输出的达标情况以及比较指标。详细介绍请参见“**预检结果**”。 2. (可选)如果需要保存比对不达标的输入和输出数据,可以在run_ut执行命令结尾添加-save_error_data,例如: ```bash - atat -f pytorch run_ut -api_info ./dump.json -save_error_data + msprobe -f pytorch run_ut -api_info ./dump.json -save_error_data ``` - 数据默认会存盘到'./ut_error_data{timestamp}'路径下(相对于启动run_ut的路径),有需要的话,用户可以通过修改mstt/debug/accuracy_tools/api_accuracy_checker目录下,config.yaml文件的error_data_path参数来配置保存路径,详见“config.yaml文件说明”。 + 数据默认会存盘到'./ut_error_data{timestamp}'路径下(相对于启动run_ut的路径),有需要的话,用户可以通过error_data_path参数来配置保存路径,error_data_path参数在[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件或config.yaml文件配置,config.json文件需要在run_ut操作时通过-config参数指定,config.yaml文件详见“**config.yaml文件说明**”。 #### 使用multi_run_ut.py执行多线程预检 @@ -73,7 +75,7 @@ multi_run_ut.py脚本,可以并行执行多个run_ut操作,从而降低预 命令示例如下: ```bash -atat -f pytorch multi_run_ut -api_info ./dump.json -n 32 -d 0 1 2 3 +msprobe -f pytorch multi_run_ut -api_info ./dump.json -n 32 -d 0 1 2 3 ``` | 参数名称 | 说明 | 是否必选 | @@ -96,26 +98,68 @@ atat -f pytorch multi_run_ut -api_info ./dump.json -n 32 -d 0 1 2 3 断点续检操作通过如下命令执行: ```bash -atat -f pytorch run_ut -api_info ./dump.json -csv_path /home/xxx/ut/accuracy_checking_result_{timestamp}.csv +msprobe -f pytorch run_ut -api_info ./dump.json -csv_path /home/xxx/ut/accuracy_checking_result_{timestamp}.csv ``` -#### API预检白名单 +#### API预检黑名单和白名单 -run_ut过程支持API预检白名单,操作方式如下: +run_ut过程支持API预检黑名单和白名单,通过如下文件配置black_list(黑名单)或white_list(白名单)参数来指定不需要或需要预检的API名称: -修改mstt/debug/accuracy_tools/api_accuracy_checker目录下config.yaml文件的white_list参数,配置需要预检的API名称,详见“config.yaml文件说明”。 +- 配置[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件,config.json文件需要在run_ut操作时通过-config参数指定。 +- 配置config.yaml文件,详见“**config.yaml文件说明**”。 + +config.json文件的优先级高于config.yaml文件,即执行config.json文件时,config.yaml文件的配置不生效。 ### config.yaml文件说明 -config.yaml文件可以通过配置参数来控制dump和run_ut操作的白名单等功能。 +config.yaml文件可以通过配置参数来控制dump和run_ut操作的白名单、黑名单等功能。操作步骤如下: + +1. 查找msprobe工具安装路径。 + + ```bash + pip show mindstudio-probe + ``` + + 输出结果如下示例: + + ```bash + Name: mindstudio-probe + Version: 1.0 + Summary: This is a pytorch precision comparison tools + Home-page: + Author: + Author-email: + License: + Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages + Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel + Required-by: + ``` + + Location字段为msprobe工具的安装路径,那么config.yaml文件位置为/home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/pytorch/api_accuracy_checker/config.yaml + +2. 进入config.yaml文件 + + ```bash + vi /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/pytorch/api_accuracy_checker/config.yaml + ``` + +3. 修改config.yaml文件参数。 + + ```yaml + white_list: [] + black_list: [] + error_data_path: './' + precision: 14 + ``` -文件路径为:mstt/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/config.yaml + | 参数名称 | 说明 | 是否必选 | + | --------------- | ------------------------------------------------------------ | -------- | + | white_list | API dump白名单,仅对指定的API进行dump。参数示例:white_list=["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 | + | black_list | API dump黑名单,被指定的API不进行dump。参数示例:black_list=["conv1d", "conv2d"]。默认未配置黑名单,即dump全量API数据。 | 否 | + | error_data_path | 配置保存精度未达标的API输入输出数据路径。参数示例"error_data_path": "./"。默认为当前路径。 | 否 | + | precision | 浮点数表示位数,默认取小数点后14位。 | 否 | -| 参数名称 | 说明 | 是否必选 | -| --------------- | ------------------------------------------------------------ | -------- | -| white_list | API dump白名单,指定dump具体API数据,也可以直接配置预检的API白名单,详细请参见“**API预检白名单**”。参数示例:white_list=["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 | -| error_data_path | 配置保存精度未达标的API输入输出数据路径。 | 否 | -| precision | 浮点数表示位数,默认取小数点后14位。 | 否 | + 说明:white_list和black_list同时配置时,二者配置的API名单若无交集,则白名单生效,若API名单存在交集,则白名单排除的部分以及交集的API不进行dump。 ## 预检结果 @@ -203,7 +247,7 @@ API预检通过测试,则在`accuracy_checking_details_{timestamp}.csv`文件 需要同时获取NPU和GPU环境下run_ut操作的预检结果`accuracy_checking_details_{timestamp}.csv`文件。执行如下命令进行NPU和GPU预检结果的比对: ```bash -atat -f pytorch api_precision_compare -npu /home/xxx/npu/accuracy_checking_details_{timestamp}.csv -gpu /home/xxx/gpu/accuracy_checking_details_{timestamp}.csv -o /home/xxx/ +msprobe -f pytorch api_precision_compare -npu /home/xxx/npu/accuracy_checking_details_{timestamp}.csv -gpu /home/xxx/gpu/accuracy_checking_details_{timestamp}.csv -o /home/xxx/ ``` | 参数名称 | 说明 | 是否必选 | diff --git a/debug/accuracy_tools/atat/pytorch/doc/dump.md b/debug/accuracy_tools/msprobe/pytorch/doc/dump.md similarity index 66% rename from debug/accuracy_tools/atat/pytorch/doc/dump.md rename to debug/accuracy_tools/msprobe/pytorch/doc/dump.md index 1e401b4f5a22fc137872f8678591b9c1eca094b6..7e393cd1026f5881de98ad830ddcc29bc6946673 100644 --- a/debug/accuracy_tools/atat/pytorch/doc/dump.md +++ b/debug/accuracy_tools/msprobe/pytorch/doc/dump.md @@ -1,8 +1,8 @@ # **精度数据采集** -atat工具主要通过在训练脚本内添加dump接口并启动训练的方式来采集精度数据。 +msprobe工具主要通过在训练脚本内添加dump接口并启动训练的方式来采集精度数据。 -执行dump操作需要安装atat工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 +执行dump操作需要安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 ## dump接口介绍 @@ -12,7 +12,7 @@ atat工具主要通过在训练脚本内添加dump接口并启动训练的方式 通过加载dump配置文件的方式来确定dump操作的详细配置。 -可以在from atat.pytorch import PrecisionDebugger和模型初始化之间的任意位置添加该接口。 +PrecisionDebugger接口可以在from msprobe.pytorch import PrecisionDebugger之后的位置添加。详细使用可参考“**示例代码**”或“**model配置代码示例**”。 **原型** @@ -20,7 +20,7 @@ atat工具主要通过在训练脚本内添加dump接口并启动训练的方式 PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, model=None, step=None) ``` -说明:上述参数除config_path和model外,其他参数均在[config.json](../../config)文件中可配,此处的参数优先级高于[config.json](../../config)文件中的配置,而config.json文件可以配置更多参数,若需要进行更多场景的精度数据dump,建议配置[config.json](../../config)文件。 +说明:上述参数除config_path和model外,其他参数均在[config.json](../../config)文件中可配,此处的参数优先级高于[config.json](../../config)文件中的配置,而config.json文件可以配置更多参数,若需要进行更多场景的精度数据dump,建议配置[config.json](../../config)文件。config.json文件的配置可参考《[配置文件说明](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/config/README.md)》。 **参数说明** @@ -44,7 +44,7 @@ import torch import torch.nn as nn import torch_npu import torch.nn.functional as F -from atat.pytorch import PrecisionDebugger +from msprobe.pytorch import PrecisionDebugger torch.npu.set_device("npu:0") #定义一个简单的网络 @@ -77,9 +77,9 @@ if __name__ == "__main__" **功能说明** -启动函数。 +dump启动函数。 -在模型初始化之后的任意位置添加。 +在模型初始化之后的位置添加。需要与stop函数一起添加在for循环内。 **原型** @@ -93,9 +93,9 @@ debugger.start() **功能说明** -停止函数。 +dump停止函数。 -在**start**函数之后的任意位置添加。 +在**start**函数之后的任意位置添加。若需要dump反向数据,则需要添加在反向计算代码(如loss.backward)之后。 **原型** @@ -105,13 +105,33 @@ debugger.stop() 该函数为类函数,可以使用debugger.stop()也可以使用PrecisionDebugger.stop()。 +### forward_backward_dump_end函数 + +**功能说明** + +dump停止函数。用于dump指定代码的前反向数据。 + +在**start**函数之后,反向计算代码(如loss.backward)之前的任意位置添加,可以dump **start**函数和该函数之间的前反向数据,可以通过调整**start**函数与该函数的位置,来指定需要dump的代码块。 + +要求**stop**函数添加在反向计算代码(如loss.backward)之后,此时该函数与**stop**函数之间的代码不会被dump。 + +使用示例参见“**示例代码 > 扩展示例**”。 + +**原型** + +```Python +forward_backward_dump_end() +``` + +该函数为类函数,可以使用debugger.forward_backward_dump_end()也可以使用PrecisionDebugger.forward_backward_dump_end()。 + ### step函数 **功能说明** 结束标识。 -在最后一个**stop**函数后或一个step结束的位置添加。 +在最后一个**stop**函数后或一个step结束的位置添加。需要与start函数一起添加在for循环内。 **原型** @@ -123,24 +143,57 @@ debugger.step() ## 示例代码 +### 基础操作 + +如下示例可dump完整代码的前反向数据。 + ```Python -from atat.pytorch import PrecisionDebugger +from msprobe.pytorch import PrecisionDebugger + +# 请勿将PrecisionDebugger的初始化流程插入到循环代码中 debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path") -# 请勿将以上初始化流程插入到循环代码中 -# 模型初始化 -# 下面代码也可以用PrecisionDebugger.start()和PrecisionDebugger.stop() -debugger.start() +# 模型、损失函数的定义及初始化等操作 +# ... -# 需要dump的代码片段1 +# 数据集迭代的位置一般为模型训练开始的位置 +for data, label in data_loader: + debugger.start() # 开启数据dump -debugger.stop() -debugger.start() + # 如下是模型每个step执行的逻辑 + output = model(data) + #... + loss.backward() + + debugger.stop() # 关闭数据dump + debugger.step() # 结束一个step的dump +``` -# 需要dump的代码片段2 +### 扩展示例 -debugger.stop() -debugger.step() +如下示例dump指定代码块前反向数据。 + +```Python +from msprobe.pytorch import PrecisionDebugger + +# 请勿将PrecisionDebugger的初始化流程插入到循环代码中 +debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path") + +# 模型、损失函数的定义及初始化等操作 +# ... + +# 数据集迭代的位置一般为模型训练开始的位置 +for data, label in data_loader: + debugger.start() # 开启数据dump + + # 如下是模型每个step执行的逻辑 + output = model(data) + debugger.forward_backward_dump_end() # 插入该函数到start函数之后,只dump start函数到该函数之间代码的前反向数据,本函数到stop函数之间的数据则不dump + #... + loss.backward() + + debugger.stop() # 关闭数据dump + debugger.step() # 结束一个step的dump ``` ## dump结果文件介绍 @@ -193,7 +246,7 @@ pt文件保存的前缀和PyTorch对应关系如下: ## 工具支持的API列表 -atat工具维护固定的API支持列表,若需要删除或增加dump的API,可以在atat/pytorch/hook_module/support_wrap_ops.yaml文件内手动修改,如下示例: +msprobe工具维护固定的API支持列表,若需要删除或增加dump的API,可以在msprobe/pytorch/hook_module/support_wrap_ops.yaml文件内手动修改,如下示例: ```Python functional: # functional为算子类别,找到对应的类别,在该类别下按照下列格式删除或添加API diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_1.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_1.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_1.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_1.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_2.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_2.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_2.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_2.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_3.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_3.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_3.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_3.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_4.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_4.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/BLOOM-7B_4.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/BLOOM-7B_4.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_1.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_1.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_1.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_1.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_2.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_2.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_2.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_2.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_3.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_3.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_3.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_3.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_4.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_4.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_4.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_4.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_5.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_5.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_5.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_5.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_6.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_6.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_6.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_6.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_7.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_7.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_7.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_7.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_8.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_8.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/GPT-3_8.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/GPT-3_8.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/YOLOV5S_1.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/YOLOV5S_1.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/YOLOV5S_1.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/YOLOV5S_1.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/YOLOV5S_2.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/YOLOV5S_2.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/YOLOV5S_2.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/YOLOV5S_2.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/accuracy_checking_details.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/accuracy_checking_details.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/accuracy_checking_details.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/accuracy_checking_details.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/accuracy_checking_result.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/accuracy_checking_result.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/accuracy_checking_result.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/accuracy_checking_result.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/api_precision_compare_details.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/api_precision_compare_details.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/api_precision_compare_details.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/api_precision_compare_details.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/api_precision_compare_result.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/api_precision_compare_result.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/api_precision_compare_result.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/api_precision_compare_result.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/auto_analyze_log.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/auto_analyze_log.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/auto_analyze_log.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/auto_analyze_log.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/compare_result_pkl.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/compare_result_pkl.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/compare_result_pkl.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/compare_result_pkl.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/compare_result_pkl_md5.png.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/compare_result_pkl_md5.png.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/cpu_info.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/cpu_info.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/cpu_info.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/cpu_info.png diff --git a/debug/accuracy_tools/atat/pytorch/doc/img/module_compare.png b/debug/accuracy_tools/msprobe/pytorch/doc/img/module_compare.png similarity index 100% rename from debug/accuracy_tools/atat/pytorch/doc/img/module_compare.png rename to debug/accuracy_tools/msprobe/pytorch/doc/img/module_compare.png diff --git "a/debug/accuracy_tools/atat/pytorch/doc/atat\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" "b/debug/accuracy_tools/msprobe/pytorch/doc/msprobe\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" similarity index 97% rename from "debug/accuracy_tools/atat/pytorch/doc/atat\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" rename to "debug/accuracy_tools/msprobe/pytorch/doc/msprobe\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" index ed175ff30172a54d8d4868097599ab8518b45e4f..c9db3ae78d7d47330cf6cddcc66c741c77a63514 100644 --- "a/debug/accuracy_tools/atat/pytorch/doc/atat\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" +++ "b/debug/accuracy_tools/msprobe/pytorch/doc/msprobe\347\262\276\345\272\246\345\267\245\345\205\267\346\225\260\346\215\256dump\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" @@ -1,4 +1,4 @@ -# atat精度工具标准性能基线报告 +# msprobe精度工具标准性能基线报告 ## 环境信息 @@ -16,7 +16,7 @@ CANN:8.0.T2 ## 模型信息和性能基线 -大模型在使用atat工具dump数据时,建议先简化模型层数,减少dump数据量。 +大模型在使用msprobe工具dump数据时,建议先简化模型层数,减少dump数据量。 以下场景的性能基线测试数据均为多次测试后取平均值,因此实际运行时性能数据可能会根据环境状态稍有浮动。 diff --git a/debug/accuracy_tools/atat/pytorch/doc/parse_tool.md b/debug/accuracy_tools/msprobe/pytorch/doc/parse_tool.md similarity index 98% rename from debug/accuracy_tools/atat/pytorch/doc/parse_tool.md rename to debug/accuracy_tools/msprobe/pytorch/doc/parse_tool.md index 23000912910e8f95b4cb74c7983961918bd9a513..81efa10fa3ec4307603e24e9599c5a00367462d4 100644 --- a/debug/accuracy_tools/atat/pytorch/doc/parse_tool.md +++ b/debug/accuracy_tools/msprobe/pytorch/doc/parse_tool.md @@ -6,10 +6,10 @@ ## 进入parse交互式界面 -安装atat工具后(详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节),可以通过使用命令 **atat -f pytorch parse** 进入交互式界面,如下所示: +安装msprobe工具后(详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节),可以通过使用命令 **msprobe -f pytorch parse** 进入交互式界面,如下所示: ```bash -atat -f pytorch parse +msprobe -f pytorch parse Parse >>> ``` @@ -23,7 +23,7 @@ Parse >>> Ctrl+C可以退出parse交互式界面。不退出parse交互式界面若需要执行非该界面下的内置Shell命令,且命令与parse交互式界面命令冲突时,非该界面命令需要使用run命令,在相关命令前加上run前缀,如下示例: ```bash -atat -f pytorch parse +msprobe -f pytorch parse Parse >>> run vim cli.py Parse >>> vim cli.py ``` diff --git a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_compare.md b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_compare.md similarity index 99% rename from debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_compare.md rename to debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_compare.md index e3537594c4f8c9e277ca867172875e3e28c23113..4bd05c73e21c4491ee8286366a6b987a60ee69ae 100644 --- a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_compare.md +++ b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_compare.md @@ -44,7 +44,7 @@ compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs) 创建比对脚本,例如compare_distributed.py,拷贝如下代码,具体参数请根据实际环境修改。 ```Python -from atat.pytorch import * +from msprobe.pytorch import * compare_distributed('./npu_dump/step0', './gpu_dump/step0', './output') ``` @@ -77,7 +77,7 @@ compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_mat 单机单卡场景下创建比对脚本,例如compare.py,拷贝如下代码,具体参数请根据实际环境修改。 ```Python -from atat.pytorch import compare +from msprobe.pytorch import compare dump_result_param={ "npu_json_path": "./npu_dump/dump.json", "bench_json_path": "./gpu_dump/dump.json", @@ -96,7 +96,7 @@ compare(dump_result_param, output_path="./output", stack_mode=True) 以compare.py为例。 ```Python -from atat.pytorch import compare +from msprobe.pytorch import compare dump_result_param={ "npu_json_path": "./npu_dump/dump.json", "bench_json_path": "./gpu_dump/dump.json", diff --git a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_overview.md b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_overview.md similarity index 81% rename from debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_overview.md rename to debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_overview.md index 708d90b3487c47249c5f6a8b0f37671e8918e7e2..019451454877eddd3cf6e59cc1eef1c48fcf2a3c 100644 --- a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_overview.md +++ b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_overview.md @@ -4,7 +4,7 @@ 在PyTorch训练网络,对同一模型或API调试过程中,遇到API相关的计算精度问题,定位时费时费力。 -atat的精度比对工具,用来进行PyTorch整网API粒度的数据dump、精度比对和溢出检测,从而定位PyTorch训练场景下的精度问题。 +msprobe的精度比对工具,用来进行PyTorch整网API粒度的数据dump、精度比对和溢出检测,从而定位PyTorch训练场景下的精度问题。 **使用场景** @@ -42,17 +42,17 @@ atat的精度比对工具,用来进行PyTorch整网API粒度的数据dump、 1. 准备CPU或GPU训练工程。 -2. 在环境下安装atat工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 +2. 在环境下安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 -3. 在训练脚本内添加atat工具dump接口PrecisionDebugger采集标杆数据。详见《[精度数据采集](./dump.md)》。 +3. 在训练脚本内添加msprobe工具dump接口PrecisionDebugger采集标杆数据。详见《[精度数据采集](./dump.md)》。 4. 执行训练dump数据。 5. 将CPU或GPU训练工程迁移为NPU训练工程。详见《[PyTorch模型迁移调优指南](https://www.hiascend.com/document/detail/zh/Pytorch/60RC1/ptmoddevg/trainingmigrguide/PT_LMTMOG_0003.html)》。 -6. 在NPU环境下安装atat工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 +6. 在NPU环境下安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 -7. 在NPU训练脚本内添加atat工具dump接口PrecisionDebugger采集标杆数据。详见《[精度数据采集](./dump.md)》。 +7. 在NPU训练脚本内添加msprobe工具dump接口PrecisionDebugger采集标杆数据。详见《[精度数据采集](./dump.md)》。 8. NPU环境下执行训练dump数据。 diff --git a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_quickstart.md b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_quickstart.md similarity index 94% rename from debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_quickstart.md rename to debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_quickstart.md index c05302055687fdf6071befd7ff8ad77c9e32f2df..4b6ac9de2f075ad17ff8a594bc40df4c36692f5f 100644 --- a/debug/accuracy_tools/atat/pytorch/doc/ptdbg_ascend_quickstart.md +++ b/debug/accuracy_tools/msprobe/pytorch/doc/ptdbg_ascend_quickstart.md @@ -1,8 +1,8 @@ # **精度比对工具** -本文主要介绍atat的精度比对工具的快速入门和场景化示例。 +本文主要介绍msprobe的精度比对工具的快速入门和场景化示例。 -本文介绍的操作需要安装atat工具,详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 +本文介绍的操作需要安装msprobe工具,详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。 本文介绍的操作主要是精度数据dump和精度比对,详细操作指导可参考《[精度数据采集](./dump.md)》和《[CPU或GPU与NPU精度数据比对](./ptdbg_ascend.md)》。 @@ -51,12 +51,12 @@ PyTorch训练场景的精度问题分析建议参考以下思路进行精度比 } ``` -2. 在训练脚本内添加atat工具,dump整网数据。 +2. 在训练脚本内添加msprobe工具,dump整网数据。 分别dump CPU或GPU以及NPU数据,在PyTorch训练脚本插入dump接口,示例代码如下(下面以NPU为例,CPU或GPU dump基本相同): ```python - from atat.pytorch import PrecisionDebugger + from msprobe.pytorch import PrecisionDebugger debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump") # 请勿将以上初始化流程插入到循环代码中 @@ -82,7 +82,7 @@ PyTorch训练场景的精度问题分析建议参考以下思路进行精度比 创建并配置精度比对脚本,以创建compare.py为例,示例代码如下: ```python - from atat.pytorch import compare + from msprobe.pytorch import compare dump_result_param={ "npu_json_path": "./npu_dump/dump.json", "bench_json_path": "./gpu_dump/dump.json", @@ -140,10 +140,10 @@ python3 compare.py } ``` -2. 在NPU训练脚本内添加atat工具,执行溢出检测dump。 +2. 在NPU训练脚本内添加msprobe工具,执行溢出检测dump。 ```python - from atat.pytorch import PrecisionDebugger + from msprobe.pytorch import PrecisionDebugger debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump") # 请勿将以上初始化流程插入到循环代码中 @@ -171,7 +171,7 @@ python3 compare.py 溢出解析工具执行命令如下: ```bash - atat -f pytorch run_overflow_check -api_info ./dump.json + msprobe -f pytorch run_overflow_check -api_info ./dump.json ``` 反向过程溢出的API暂不支持精度预检功能。 @@ -200,7 +200,7 @@ python3 compare.py 1. 创建比对脚本,例如compare_distributed.py,拷贝如下代码。 ```python - from atat.pytorch import * + from msprobe.pytorch import * compare_distributed('./npu_dump/step0', './gpu_dump/step0', './output') ``` @@ -219,7 +219,7 @@ python3 compare.py 多卡一般为多进程,须保证每个进程都正确调用PrecisionDebugger,或把PrecisionDebugger插入到import语句后,如: ```python -from atat.pytorch import PrecisionDebugger +from msprobe.pytorch import PrecisionDebugger debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump") ``` @@ -339,10 +339,10 @@ debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump" } ``` -2. 在训练脚本内添加atat工具,dump整网数据。 +2. 在训练脚本内添加msprobe工具,dump整网数据。 ```python - from atat.pytorch import PrecisionDebugger + from msprobe.pytorch import PrecisionDebugger debugger = PrecisionDebugger(config_path="./config.json", dump_path="./npu_dump") # 请勿将以上初始化流程插入到循环代码中 diff --git a/debug/accuracy_tools/atat/pytorch/doc/run_overflow_check.md b/debug/accuracy_tools/msprobe/pytorch/doc/run_overflow_check.md similarity index 95% rename from debug/accuracy_tools/atat/pytorch/doc/run_overflow_check.md rename to debug/accuracy_tools/msprobe/pytorch/doc/run_overflow_check.md index 1bdc4f354cfaf0bfbdf701baa7dfb05f3771e30b..b8c9c3b4c292886e2ef8229ec421a244ae38f92e 100644 --- a/debug/accuracy_tools/atat/pytorch/doc/run_overflow_check.md +++ b/debug/accuracy_tools/msprobe/pytorch/doc/run_overflow_check.md @@ -13,7 +13,7 @@ 2. 执行溢出API解析操作。 ```bash - atat -f pytorch run_overflow_check -api_info ./dump.json + msprobe -f pytorch run_overflow_check -api_info ./dump.json ``` | 参数名称 | 说明 | 是否必选 | diff --git a/debug/accuracy_tools/msprobe/pytorch/doc/torchair_compare.md b/debug/accuracy_tools/msprobe/pytorch/doc/torchair_compare.md new file mode 100644 index 0000000000000000000000000000000000000000..a27ba388b030dbd534c6d79b160a8bb3b87bd5a8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/doc/torchair_compare.md @@ -0,0 +1,43 @@ +# TorchAir训练场景-整网算子精度比对 + +TorchAir训练场景数据dump和比对能力继承自torchair推理场景。数据dump和比对基本资料可参考[基于torch图模式(torchair)推理场景](https://gitee.com/ascend/msit/blob/master/msit/docs/llm/TorchAir%E5%9C%BA%E6%99%AF-%E6%95%B4%E7%BD%91%E7%AE%97%E5%AD%90%E7%B2%BE%E5%BA%A6%E6%AF%94%E5%AF%B9.md)。 + +## 工具安装 + +1. 克隆mstt仓poc分支后,进入mstt/debug/accuracy_tools目录,执行python setup.py bdist_wheel,进行poc分支的msprobe包构建,构建完成后,在accuracy_tools目录下生成的dist文件夹,dist下生成whl安装包,安装whl包。 +2. 命令行中输入add_torchair_compare_path。提示torchair compare path added successfully,则说明可以使用图模式数据dump和比对能力。 + +## 约束 + +- 数据dump时,torch.compile()函数训练场景不接受参数dynamic=True。 +- FX模式dump训练脚本需要单独放在一个目录下。 + +## 比对操作 + +1. 数据dump + + 用法见[基于torch图模式(torchair)推理场景](https://gitee.com/ascend/msit/blob/master/msit/docs/llm/TorchAir%E5%9C%BA%E6%99%AF-%E6%95%B4%E7%BD%91%E7%AE%97%E5%AD%90%E7%B2%BE%E5%BA%A6%E6%AF%94%E5%AF%B9.md)。可参考[TorchAir场景Dump案例](./torchair_dump_sample.md)。 + +2. 精度比对 + + 比对通过命令行子命令torchair_compare执行,命令如下: + + ```bash + msprobe -f pytorch torchair_compare --my-path [my dump data] --golden-path [gold dump data] --output-path [output path] + ``` + +3. 查看比对结果,请参考《[PyTorch 场景的精度比对-精度比对结果分析](https://gitee.com/ascend/mstt/blob/e4206fa5268a96f991dc7a3444c6ce16040089ed/debug/accuracy_tools/msprobe/docs/10.accuracy_compare_PyTorch.md#3-精度比对结果分析)》章节。 + +TorchAir训练场景支持如下两种比对场景: + +- GE融合模式(默认)dump数据与FX dump数据精度比对 + + ```bash + msprobe -f pytorch torchair_compare --my-path {dump_path}/dump_{time_stamp} --golden-path {dump_path} + ``` + +- GE融合模式(默认)dump数据与GE关闭融合模式dump数据精度比对 + + ```bash + msprobe -f pytorch torchair_compare --my-path {dump_path}/dump_{time_stamp} --golden-path {dump_path}/dump_{time_stamp} + ``` diff --git a/debug/accuracy_tools/msprobe/pytorch/doc/torchair_dump_sample.md b/debug/accuracy_tools/msprobe/pytorch/doc/torchair_dump_sample.md new file mode 100644 index 0000000000000000000000000000000000000000..b809113712e7a03cff7b9d55019b4576960a1f05 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/doc/torchair_dump_sample.md @@ -0,0 +1,131 @@ +# TorchAir训练场景Dump案例 + +## GE开启融合(默认)dump + +```python +import os +import numpy as np +import torch, torch_npu, torchvision +import torchair as tng +from msit_llm.dump import torchair_dump +import torch.optim as optim + +target_dtype = torch.float16 +model = torchvision.models.resnet50(pretrained=True).eval().to(target_dtype).npu() +if not os.path.exists('aa_224_224.npy'): + np.save('aa_224_224.npy', np.random.uniform(size=[1, 3, 224, 224])) +aa = torch.from_numpy(np.load('aa_224_224.npy')).to(target_dtype).npu() + +config = torchair_dump.get_ge_dump_config(dump_path="dump_gefusion_open_train") +npu_backend = tng.get_npu_backend(compiler_config=config) +model = torch.compile(model, backend=npu_backend) + +criterion = torch.nn.CrossEntropyLoss() +optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + +if not os.path.exists('bb_1.pt'): + torch.save(torch.randint(0, 1000, (1,), dtype=torch.long), 'bb_1.pt') +bb = torch.load('bb_1.pt').npu() + +optimizer.zero_grad() +output = model(aa) +loss = criterion(output, bb) +print(loss) +print("loss finished") +loss.backward() +print("backward finished") +optimizer.step() +print(f"loss gradient: {loss.grad}") +``` + +## FX dump + +```python +import os +import numpy as np +import torch, torch_npu, torchvision +import torchair as tng +from msit_llm.dump import torchair_dump +import torch.optim as optim + +target_dtype = torch.float16 +model = torchvision.models.resnet50(pretrained=True).eval().to(target_dtype).npu() +if not os.path.exists('../aa_224_224.npy'): + np.save('../aa_224_224.npy', np.random.uniform(size=[1, 3, 224, 224])) +aa = torch.from_numpy(np.load('../aa_224_224.npy')).to(target_dtype).npu() + +config = torchair_dump.get_fx_dump_config() +npu_backend = tng.get_npu_backend(compiler_config=config) +model = torch.compile(model, backend=npu_backend) + +criterion = torch.nn.CrossEntropyLoss() +optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + +if not os.path.exists('../bb_1.pt'): + torch.save(torch.randint(0, 1000, (1,), dtype=torch.long), '../bb_1.pt') +bb = torch.load('../bb_1.pt').npu() + +optimizer.zero_grad() +output = model(aa) +loss = criterion(output, bb) +print(loss) +print("loss finished") +loss.backward() +print("backward finished") +optimizer.step() +print(f"loss gradient: {loss.grad}") +``` + +## GE关闭融合模式dump + +```python +import os +import numpy as np +import torch, torch_npu, torchvision +import torchair as tng +from msit_llm.dump import torchair_dump +import torch.optim as optim + +target_dtype = torch.float16 +model = torchvision.models.resnet50(pretrained=True).eval().to(target_dtype).npu() +if not os.path.exists('aa_224_224.npy'): + np.save('aa_224_224.npy', np.random.uniform(size=[1, 3, 224, 224])) +aa = torch.from_numpy(np.load('aa_224_224.npy')).to(target_dtype).npu() + +config = torchair_dump.get_ge_dump_config(dump_path="dump_gefusion_close_train", fusion_switch_file='./fusion_switch.json') +npu_backend = tng.get_npu_backend(compiler_config=config) +model = torch.compile(model, backend=npu_backend) + +criterion = torch.nn.CrossEntropyLoss() +optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + +if not os.path.exists('bb_1.pt'): + torch.save(torch.randint(0, 1000, (1,), dtype=torch.long), 'bb_1.pt') +bb = torch.load('bb_1.pt').npu() + +optimizer.zero_grad() +output = model(aa) +loss = criterion(output, bb) +print(loss) +print("loss finished") +loss.backward() +print("backward finished") +optimizer.step() +print(f"loss gradient: {loss.grad}") +``` + +新建fusion_switch.json文件,内容如下: + +```json +{ + "Switch": { + "GraphFusion": { + "ALL": "off" + }, + "UBFusion": { + "ALL": "off" + } + } +} +``` + diff --git "a/debug/accuracy_tools/atat/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md" "b/debug/accuracy_tools/msprobe/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md" similarity index 95% rename from "debug/accuracy_tools/atat/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md" rename to "debug/accuracy_tools/msprobe/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md" index b2e373feb6cf7cb0c63dcff592939567b52738b4..05bebaf0a22d8d5d7886ec9937b32b4755caf872 100644 --- "a/debug/accuracy_tools/atat/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md" +++ "b/debug/accuracy_tools/msprobe/pytorch/doc/\345\234\250\347\272\277\347\262\276\345\272\246\346\257\224\345\257\271.md" @@ -32,8 +32,8 @@ PyTorch NPU在线精度比对是ptdbg_ascend工具实现在PyTorch训练过程 1. 在NPU训练脚本中添加在线精度比对接口,示例如下: ```python - from atat.pytorch.common.utils import seed_all - from atat.pytorch.online_dispatch import PtdbgDispatch + from msprobe.pytorch.common.utils import seed_all + from msprobe.pytorch.online_dispatch import PtdbgDispatch # 在main函数开始前固定随机数 seed_all() @@ -74,12 +74,12 @@ PyTorch NPU在线精度比对是ptdbg_ascend工具实现在PyTorch训练过程 | process_num | 多进程并发数,默认为0。 | 否 | | debug | debug信息打印,默认为False。 | 否 | ### dump数据存盘说明 -dump数据存盘目录名格式:`atat_tag_rankid_{timestamp}`。 +dump数据存盘目录名格式:`msprobe_tag_rankid_{timestamp}`。 子目录下包含1个比对结果csv文件、cpu和npudump数据目录,npu目录下包含Aten IR在NPU上的输入输出的dump数据,由于CPU的输入是直接使用NPU的输入执行,因此cpu目录下只包含执行输出的dump数据。 ```bash -atat_rank4_20230911170521 +msprobe_rank4_20230911170521 ├── compare_result_rank4_20230911170521.csv ├── cpu │   ├── native_batch_norm_backward_10_output.0.npy diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/__init__.py similarity index 43% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/__init__.py index b9d41330a87605af37f7b538e8d7bcf56f9725f1..d234898c0df158308b070e4a9147c9dff0b67c8d 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/__init__.py @@ -1,6 +1,6 @@ -from atat.core.common.log import logger -from atat.core.common.exceptions import FreeBenchmarkException -from atat.core.common.const import Const +from msprobe.core.common.log import logger +from msprobe.core.common.exceptions import FreeBenchmarkException +from msprobe.core.common.const import Const from .main import FreeBenchmarkCheck from .common.params import UnequalRow diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/constant.py similarity index 89% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/constant.py index 36b7a6491580a0137ff89958b4f6883dccbea4fc..c5e93be138d24af8c18858db483e397527fb4092 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/constant.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/constant.py @@ -2,8 +2,8 @@ from typing import Dict import numpy as np import torch -from atat.pytorch.free_benchmark.common.enums import FuzzThreshold -from atat.pytorch.free_benchmark.common.params import BenchmarkThd +from msprobe.pytorch.free_benchmark.common.enums import FuzzThreshold +from msprobe.pytorch.free_benchmark.common.params import BenchmarkThd class CommonField: @@ -52,6 +52,7 @@ class ThresholdConfig: DTYPE_PER_THD = { torch.float16: 1.002, + torch.bfloat16: 1.004, torch.float32: 1.0002, } BENCHMARK_THD_DICT = { @@ -60,6 +61,8 @@ class ThresholdConfig: torch.bfloat16: BenchmarkThd(2**-8, 1.0, 2**-8, 1e-4), } + TENSOR_SPLIT_MAX_CHUNK = 128 + class PreheatConfig: IF_PREHEAT = "if_preheat" diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/counter.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/counter.py similarity index 97% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/counter.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/counter.py index 186b75c71aeaf71fc2adab7ec38c7f00f6b7fdb7..b2f8c81f3a4ea57d712e49b0b58fc77747797323 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/counter.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/counter.py @@ -1,5 +1,5 @@ from collections import defaultdict -from atat.pytorch.free_benchmark.common.constant import ThresholdConfig +from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig class PreheatCounter: diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/enums.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/enums.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/enums.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/enums.py diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py similarity index 93% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py index 440348d78c28d3f7cc816932ff12e83aa71915bc..bbfc245a635322f0cde6951663d4d76a009ee66a 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/params.py @@ -2,13 +2,13 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple import torch -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.enums import ( +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.enums import ( DeviceType, FuzzLevel, PerturbationMode, ) -from atat.pytorch.free_benchmark.common.utils import Tools +from msprobe.pytorch.free_benchmark.common.utils import Tools @dataclass @@ -78,7 +78,7 @@ def data_pre_deal(name, func, args, kwargs): data_params.valid_input_index = index if index == -1: logger.warning_on_rank_0( - f"[atat] Free benchmark: 无标杆工具不支持当前算子的输入类型 {name}." + f"[msprobe] Free benchmark: 无标杆工具不支持当前算子的输入类型 {name}." ) return data_params diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py similarity index 92% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py index 24d25967635b3dcfd1da89e1f54d3282fa1181ed..631beeb85cbce50ba126f39be8c04347f189c50b 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py @@ -1,5 +1,5 @@ import torch -from atat.pytorch.free_benchmark.common.enums import DeviceType +from msprobe.pytorch.free_benchmark.common.enums import DeviceType class Tools: @@ -96,3 +96,7 @@ class TorchC: add = torch._C._VariableFunctionsClass.add bitwise_xor = torch._C._VariableFunctionsClass.bitwise_xor clone = torch._C._VariableFunctionsClass.clone + clamp = torch._C._VariableFunctionsClass.clamp + tensor_split = torch._C._VariableFunctionsClass.tensor_split + stack = torch._C._VariableFunctionsClass.stack + reshape = torch._C._VariableFunctionsClass.reshape diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py similarity index 89% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py index 89ef9e4c9b4500953b0edc26f28f5b14e401ca50..6781a1c2fc4c7fd348d330a49352e2f6195e8a71 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py @@ -1,10 +1,10 @@ import torch -from atat.core.common.exceptions import FreeBenchmarkException -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.constant import CommonField -from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams -from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory -from atat.pytorch.free_benchmark.result_handlers.handler_factory import ( +from msprobe.core.common.exceptions import FreeBenchmarkException +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.constant import CommonField +from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams +from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory +from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import ( FuzzHandlerFactory, ) @@ -41,18 +41,18 @@ class GradSaver: data_processor.update_unequal_rows(handler.get_unequal_rows()) except IndexError: logger.warning_on_rank_0( - f"[atat] Free benchmark: grad index out of range. api:{self.handler_params.api_name}." + f"[msprobe] Free benchmark: grad index out of range. api:{self.handler_params.api_name}." f"index:{new_grad_index}, perturbation grad len {len(self.perturbed_grad_input)}" ) return grad except FreeBenchmarkException as e: logger.warning_on_rank_0( - f"[atat] Free benchmark: grad input check error: {e}" + f"[msprobe] Free benchmark: grad input check error: {e}" ) return grad except Exception as e: logger.warning_on_rank_0( - f"[atat] Free benchmark: grad compare error: {e}" + f"[msprobe] Free benchmark: grad compare error: {e}" ) return grad return grad @@ -77,7 +77,7 @@ class GradSaver: handler.handle(self.data_params) except Exception as e: logger.warning_on_rank_0( - f"[atat] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}." + f"[msprobe] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}." f"{e}" ) # 在扰动前后输出对比后释放输出的引用 diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py similarity index 89% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py index 85aa68f13b969e996407f5f64353de43b916e00f..59239fcd004fb3472d5c8f53305e692fec00adee 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/single_benchmark.py @@ -1,9 +1,9 @@ import math import torch -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.constant import ThresholdConfig -from atat.pytorch.free_benchmark.common.utils import TorchC +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig +from msprobe.pytorch.free_benchmark.common.utils import TorchC class SingleCompare: @@ -28,6 +28,14 @@ class SingleCompare: tensor[inf_or_nan_mask] = 1 return tensor + @staticmethod + def compare_float_seq(actual, golden): + return math.isclose(actual, golden) + + @staticmethod + def compare_other_seq(actual, golden): + return actual == golden + def compare_dict_seq(self, actual, golden): if len(actual) != len(golden): return False @@ -61,7 +69,7 @@ class SingleCompare: actual.dtype, ThresholdConfig.BENCHMARK_THD_DICT.get(torch.float32) ) if self.filter_overflow(golden) > 0: - logger.warning_on_rank_0("[atat] Free Benchmark: inf and nan" + logger.warning_on_rank_0("[msprobe] Free Benchmark: inf and nan" "in golden tensor is not supported.") return True actual = self.replace_inf_or_nan(actual) @@ -76,12 +84,6 @@ class SingleCompare: return False return True - def compare_float_seq(self, actual, golden): - return math.isclose(actual, golden) - - def compare_other_seq(self, actual, golden): - return actual == golden - def _cal_compare_metrics(self, actual, golden): diff_value = TorchC.subtract(actual, golden) diff_abs = TorchC.abs(diff_value) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py similarity index 81% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/main.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py index 2ebc0a6db917334f492019fa694330d86a0f37e1..971776d1326409c8878849e7b09a4614ffbc16f5 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py @@ -1,19 +1,19 @@ from abc import ABC import torch -from atat.core.common.const import Const -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.constant import CommonField -from atat.pytorch.free_benchmark.common.enums import ( +from msprobe.core.common.const import Const +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.constant import CommonField +from msprobe.pytorch.free_benchmark.common.enums import ( DeviceType, FuzzLevel, HandlerType, PerturbationMode, ) -from atat.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params -from atat.pytorch.free_benchmark.compare.grad_saver import GradSaver -from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory -from atat.pytorch.free_benchmark.result_handlers.handler_factory import ( +from msprobe.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params +from msprobe.pytorch.free_benchmark.compare.grad_saver import GradSaver +from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory +from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import ( FuzzHandlerFactory, ) @@ -81,7 +81,7 @@ class FreeBenchmarkCheck(ABC): grad_saver = getattr(module, CommonField.GRADSAVER) except AttributeError: logger.warning_on_rank_0( - f"[atat] Free benchmark: get grad saver failed. api_name:{name}" + f"[msprobe] Free benchmark: get grad saver failed. api_name:{name}" ) return @@ -97,6 +97,6 @@ class FreeBenchmarkCheck(ABC): ) except Exception as e: logger.warning_on_rank_0( - f"[atat] Free benchmark: grad vjp calculate failed. api_name:{name} error: {e}" + f"[msprobe] Free benchmark: grad vjp calculate failed. api_name:{name} error: {e}" ) return diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/base_layer.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py similarity index 78% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/base_layer.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py index aa572fd8e8dc8b62493dfa1fecc587b934c83a99..f64a201d5efa007ff4ed848eafd2ab6db535a2f5 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/base_layer.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any -from atat.pytorch.free_benchmark.common.params import DataParams +from msprobe.pytorch.free_benchmark.common.params import DataParams class BaseLayer(ABC): diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/layer_factory.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py similarity index 62% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/layer_factory.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py index 0d09438ce04132c9c5c301d758dc06818805082e..0ea9107aa84c2633435fe616891f5386b17de423 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/layer_factory.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py @@ -1,15 +1,15 @@ -from atat.pytorch.free_benchmark import FreeBenchmarkException -from atat.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode -from atat.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import ( +from msprobe.pytorch.free_benchmark import FreeBenchmarkException +from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.improve_precision import ( ImprovePrecisionLayer, ) -from atat.pytorch.free_benchmark.perturbed_layers.npu.add_noise import AddNoiseLayer -from atat.pytorch.free_benchmark.perturbed_layers.npu.bit_noise import BitNoiseLayer -from atat.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer -from atat.pytorch.free_benchmark.perturbed_layers.npu.change_value import ( +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.add_noise import AddNoiseLayer +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.bit_noise import BitNoiseLayer +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.no_change import NoChangeLayer +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.change_value import ( ChangeValueLayer, ) -from atat.pytorch.free_benchmark.perturbed_layers.run_cpu import CpuLayer +from msprobe.pytorch.free_benchmark.perturbed_layers.run_cpu import CpuLayer class LayerFactory: diff --git a/debug/accuracy_tools/atat/pytorch/functional/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/functional/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py similarity index 78% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py index af8a93f7d4b9b06623b70c22e7fb5065305e84a0..a18ef1c51bd342c9b3ab5ffecf14c307e9be5527 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py @@ -1,10 +1,10 @@ import torch -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.constant import ThresholdConfig -from atat.pytorch.free_benchmark.common.enums import PerturbationMode -from atat.pytorch.free_benchmark.common.params import DataParams -from atat.pytorch.free_benchmark.common.utils import TorchC -from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig +from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode +from msprobe.pytorch.free_benchmark.common.params import DataParams +from msprobe.pytorch.free_benchmark.common.utils import TorchC +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -37,7 +37,7 @@ class AddNoiseLayer(NpuBaseLayer): 对输入添加扰动并返回 """ logger.info_on_rank_0( - f"[atat] Free benchmark: Perturbation is " + f"[msprobe] Free benchmark: Perturbation is " f"{PerturbationMode.ADD_NOISE} of {self.api_name}." ) params.perturbed_value = self.add_noise(params.args[params.valid_input_index]) @@ -60,13 +60,13 @@ class AddNoiseLayer(NpuBaseLayer): """ if not self.perturbed_value: logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.api_name}, " + f"[msprobe] Free Benchmark: For {self.api_name}, " f"dtype unsupported. Cancel perturbation." ) return False if tensor_obj.numel() == 0: logger.warning_on_rank_0( - f"[atat] Free benchmark: For {self.api_name}, tensor shape must > 0." + f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0." f" Cancel adding noise." ) return False @@ -77,13 +77,13 @@ class AddNoiseLayer(NpuBaseLayer): max_val = TorchC.max(TorchC.abs(tensor_obj)).item() except Exception: logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.api_name}, " + f"[msprobe] Free Benchmark: For {self.api_name}, " f"when calculate maximun value, tensor is changed to float32." ) max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() if max_val < abs_tol: logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.api_name}, " + f"[msprobe] Free Benchmark: For {self.api_name}, " f"Maximun value is less than the minimun threshold. Cancel add noise." ) return False diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py similarity index 80% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py index 40b99acf41105fa61792ef52e27cc7f2e6686ba7..45dea7b93a5c7628b24bf0470af10af355a7742f 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py @@ -1,10 +1,10 @@ import torch -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.constant import ThresholdConfig -from atat.pytorch.free_benchmark.common.enums import PerturbationMode -from atat.pytorch.free_benchmark.common.params import DataParams -from atat.pytorch.free_benchmark.common.utils import TorchC -from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig +from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode +from msprobe.pytorch.free_benchmark.common.params import DataParams +from msprobe.pytorch.free_benchmark.common.utils import TorchC +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -53,7 +53,7 @@ class BitNoiseLayer(NpuBaseLayer): 对输入添加扰动并返回 """ logger.info_on_rank_0( - f"[atat] Free benchmark: Perturbation is " + f"[msprobe] Free benchmark: Perturbation is " f"{PerturbationMode.BIT_NOISE} of {self.api_name}." ) params.perturbed_value = self.add_bit_noise(params.args[params.valid_input_index]) @@ -65,13 +65,13 @@ class BitNoiseLayer(NpuBaseLayer): """ if not self.bit_type: logger.info_on_rank_0( - f"[atat] Free Benchmark: For {self.api_name}, " + f"[msprobe] Free Benchmark: For {self.api_name}, " f"dtype unsupported. Cancel perturbation." ) return False if tensor_obj.numel() == 0: logger.warning_on_rank_0( - f"[atat] Free benchmark: For {self.api_name}, tensor shape must > 0" + f"[msprobe] Free benchmark: For {self.api_name}, tensor shape must > 0" f" Cancel adding noise." ) return False @@ -82,13 +82,13 @@ class BitNoiseLayer(NpuBaseLayer): max_val = TorchC.max(TorchC.abs(tensor_obj)).item() except Exception: logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.api_name}, " + f"[msprobe] Free Benchmark: For {self.api_name}, " f"when calculate maximun value, tensor is changed to float32." ) max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() if max_val < abs_tol: logger.info_on_rank_0( - f"[atat] Free Benchmark: For {self.api_name}, " + f"[msprobe] Free Benchmark: For {self.api_name}, " f"Maximun value is less than the minimun threshold. Cancel add noise." ) return False diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py similarity index 81% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py index b7a967e18b91ecc2d36c22afce49f72677bef565..91085d57a68b4841b2e04453c05c41a2903477c3 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py @@ -1,9 +1,9 @@ import torch -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.enums import PerturbationMode -from atat.pytorch.free_benchmark.common.params import DataParams -from atat.pytorch.free_benchmark.common.utils import TorchC -from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode +from msprobe.pytorch.free_benchmark.common.params import DataParams +from msprobe.pytorch.free_benchmark.common.utils import TorchC +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -44,7 +44,7 @@ class ChangeValueLayer(NpuBaseLayer): 对输入添加扰动并返回 """ logger.info_on_rank_0( - f"[atat] Free benchmark: Perturbation is " + f"[msprobe] Free benchmark: Perturbation is " f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}." ) params.perturbed_value = self.change_value(params.args[params.valid_input_index]) @@ -56,7 +56,7 @@ class ChangeValueLayer(NpuBaseLayer): """ if tensor_obj.size(0) < 2: logger.info_on_rank_0( - f"[atat] Free Benchmark: For {self.api_name}, " + f"[msprobe] Free Benchmark: For {self.api_name}, " f"size 0 must greater than 1. Cancel change value." ) return False diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py similarity index 83% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py index 03718e3c4d6c4c4ea28ae6eec5daddc02bcedb7d..ad6d8b8989d6983f81a9a2d58798a26d4ccc45c1 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py @@ -1,10 +1,10 @@ import torch -from atat.core.common.const import Const -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.constant import CommonField -from atat.pytorch.free_benchmark.common.enums import PerturbationMode -from atat.pytorch.free_benchmark.common.params import DataParams -from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( +from msprobe.core.common.const import Const +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.constant import CommonField +from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode +from msprobe.pytorch.free_benchmark.common.params import DataParams +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -34,7 +34,7 @@ class ImprovePrecisionLayer(NpuBaseLayer): def handle(self, params: DataParams) -> torch.Any: logger.info_on_rank_0( - f"[atat] Free benchmark: Perturbation is " + f"[msprobe] Free benchmark: Perturbation is " f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}." ) new_args = self.improve_tensor_precision(params.args) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py similarity index 64% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py index bb065385c690f937c702cadac5707b787489aee5..a69c56002a205a518a6929835591859f63b800ff 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py @@ -1,8 +1,8 @@ import torch -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.enums import PerturbationMode -from atat.pytorch.free_benchmark.common.params import DataParams -from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.enums import PerturbationMode +from msprobe.pytorch.free_benchmark.common.params import DataParams +from msprobe.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( NpuBaseLayer, ) @@ -21,7 +21,7 @@ class NoChangeLayer(NpuBaseLayer): 对输入添加扰动并返回 """ logger.info_on_rank_0( - f"[atat] Free benchmark: Perturbation is " + f"[msprobe] Free benchmark: Perturbation is " f"{PerturbationMode.NO_CHANGE} of {self.api_name}." ) params.perturbed_value = self.no_change(params.args[params.valid_input_index]) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py similarity index 90% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py index 3784af0953022f1eb981ca26cf88765044f56f3f..1a859481475bc9963a5d3b96389061a257a1a759 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py @@ -2,8 +2,8 @@ from abc import abstractmethod from typing import Any import torch -from atat.pytorch.free_benchmark.common.params import DataParams -from atat.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer +from msprobe.pytorch.free_benchmark.common.params import DataParams +from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer class NpuBaseLayer(BaseLayer): diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py similarity index 52% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py index 024958ffbe126b89ec15fa10b277d90af4ed3e45..d34ac976537d794a05255a32de8d54de2dbac5d3 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py @@ -1,9 +1,9 @@ import torch -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.params import DataParams -from atat.pytorch.free_benchmark.common.utils import Tools -from atat.pytorch.free_benchmark.common.enums import DeviceType -from atat.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.params import DataParams +from msprobe.pytorch.free_benchmark.common.utils import Tools +from msprobe.pytorch.free_benchmark.common.enums import DeviceType +from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer class CpuLayer(BaseLayer): @@ -11,7 +11,7 @@ class CpuLayer(BaseLayer): def handle(self, params: DataParams) -> torch.Any: logger.info_on_rank_0( - f"[atat] Free benchmark: Perturbation is to_cpu of {self.api_name}." + f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}." ) new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True) new_kwargs = Tools.convert_device_and_dtype(params.kwargs, DeviceType.CPU, change_dtype=True) diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/__init__.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/parse_tool/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py similarity index 64% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py index c57d7e390a0fad73a5e87d72ec72e434835618cf..e36f586735538ed2d51b3afaaec390724d9a4b02 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py @@ -1,22 +1,23 @@ import math from abc import ABC, abstractmethod from typing import Any, Optional, Tuple +import numpy as np import torch -from atat.core.common.const import Const -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.constant import ThresholdConfig -from atat.pytorch.free_benchmark.common.enums import ( +from msprobe.core.common.const import Const +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig +from msprobe.pytorch.free_benchmark.common.enums import ( FuzzThreshold, NormType, PerturbationMode, ) -from atat.pytorch.free_benchmark.common.params import ( +from msprobe.pytorch.free_benchmark.common.params import ( DataParams, HandlerParams, make_unequal_row, ) -from atat.pytorch.free_benchmark.common.utils import Tools, TorchC +from msprobe.pytorch.free_benchmark.common.utils import Tools, TorchC class FuzzHandler(ABC): @@ -34,15 +35,36 @@ class FuzzHandler(ABC): origin_ouput = origin_ouput.values perturbed_output = perturbed_output.values if hasattr(perturbed_output, "dtype"): - abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype) + abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, FuzzThreshold.F32_THD) else: - abs_tol = FuzzThreshold.F32_THD.value + abs_tol = FuzzThreshold.F32_THD return ( origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device), perturbed_output, abs_tol, ) + @staticmethod + def tensor_split_for_error_calculate(origin_output, perturbed_output): + """ + 对将投入误差值计算的扰动前后输出张量进行分块 + :param origin_output: 原始输出 + :param perturbed_output: 扰动后输出 + :return origin_output_chunks: 切块后原始输出列表 + :return perturbed_output_chunks: 切块后扰动后输出列表 + """ + single_output_mem = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB + if single_output_mem == 0 or origin_output.ndim == 0: + return [origin_output], [perturbed_output] + # 张量大小和批数之间的关系:chunks_exp=math.log(M,2)-4, chunks=2**chunks_exp (M为对比张量数据大小[Mb]) + chunks_exp = int(math.log(single_output_mem, 2)) - 4 + chunks = 2 ** chunks_exp + chunks = max(chunks, 1) + chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK) + origin_output_chunks = TorchC.tensor_split(TorchC.reshape(origin_output, (-1,)), chunks) + perturbed_output_chunks = TorchC.tensor_split(TorchC.reshape(perturbed_output, (-1,)), chunks) + return origin_output_chunks, perturbed_output_chunks + @staticmethod def convert_overflow_ratio_to_consistent(ratio): if math.isnan(ratio) or math.isinf(ratio): @@ -61,36 +83,28 @@ class FuzzHandler(ABC): self, origin_output, perturbed_output, norm_type, abs_tol ): if norm_type == NormType.ENDLESS_NORM: - return self.get_endless_norm(origin_output, perturbed_output, abs_tol) + return self.calculate_error(origin_output, perturbed_output, abs_tol) return ThresholdConfig.COMP_CONSISTENT - def get_endless_norm(self, origin_output, perturbed_output, abs_tol): - ratio_tensor1 = TorchC.where( - TorchC.gt(TorchC.abs(perturbed_output), abs_tol), - TorchC.div( - TorchC.abs(origin_output), - TorchC.add(TorchC.abs(perturbed_output), abs_tol), - ), - 1, - ) - ratio_tensor2 = TorchC.where( - TorchC.gt(TorchC.abs(origin_output), abs_tol), - TorchC.div( - TorchC.abs(perturbed_output), - TorchC.add(TorchC.abs(origin_output), abs_tol), - ), - 1, - ) + def calculate_error(self, origin_output, perturbed_output, abs_tol): + origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_error_calculate(origin_output, perturbed_output) + norm1 = -np.inf + norm2 = -np.inf + norm3 = np.inf + for i, chunk_origin in enumerate(origin_output_chunks): + if chunk_origin.nelement() == 0: + break + chunk_perturbed = perturbed_output_chunks[i] + ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol, + TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1) + ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol, + TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1) + norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)]) + max_ratio1, max_ratio2 = norm_values.tolist() + norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1)) + norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2)) + norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1)) - norm1 = self.convert_overflow_ratio_to_consistent( - TorchC.max(ratio_tensor1).item() - ) - norm2 = self.convert_overflow_ratio_to_consistent( - TorchC.max(ratio_tensor2).item() - ) - norm3 = self.convert_overflow_ratio_to_consistent( - TorchC.min(ratio_tensor1).item() - ) if norm3 < 0: ratio = ThresholdConfig.SYMBOL_FLIPPING else: @@ -104,7 +118,7 @@ class FuzzHandler(ABC): ) except Exception as e: logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name}, " + f"[msprobe] Free Benchmark: For {self.params.api_name}, " f"when computing ratio," f" y1 or y2 dtype is not supported {e}" ) @@ -133,7 +147,7 @@ class FuzzHandler(ABC): ) elif not isinstance(perturbed_output, torch.Tensor): logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name} " + f"[msprobe] Free Benchmark: For {self.params.api_name} " f"The compare for output type {type(perturbed_output)} is not supported" ) @@ -185,7 +199,7 @@ class FuzzHandler(ABC): ) except Exception as e: logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name}, " + f"[msprobe] Free Benchmark: For {self.params.api_name}, " f"when campare the result exception raise {e}" ) return npu_consistent, max_fuzz_ratio diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py similarity index 68% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py index ed846803a180e3cca7b0a05f01f592dff07d5a44..c16284eb07beda10a38755dc54349c8835ada37a 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/check_handler.py @@ -1,11 +1,11 @@ from typing import Any -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.enums import DeviceType -from atat.pytorch.free_benchmark.common.params import DataParams, make_unequal_row -from atat.pytorch.free_benchmark.common.utils import Tools -from atat.pytorch.free_benchmark.compare.single_benchmark import SingleCompare -from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.enums import DeviceType +from msprobe.pytorch.free_benchmark.common.params import DataParams, make_unequal_row +from msprobe.pytorch.free_benchmark.common.utils import Tools +from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare +from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler class CheckerHandler(FuzzHandler): @@ -33,7 +33,7 @@ class CheckerHandler(FuzzHandler): self.other_compare(data_params) except Exception as e: logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name}, " + f"[msprobe] Free Benchmark: For {self.params.api_name}, " f"when campare the result exception raise {e}" ) return data_params.original_result diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py similarity index 60% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py index fa5c6f37495323693175117040c9c2f7fa3c01c6..a1d90035e847abb26c0838635666ff0425853513 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py @@ -1,9 +1,9 @@ from typing import Any -from atat.pytorch.free_benchmark.common.params import DataParams -from atat.pytorch.free_benchmark.common.utils import Tools -from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler -from atat.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.params import DataParams +from msprobe.pytorch.free_benchmark.common.utils import Tools +from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler +from msprobe.pytorch.free_benchmark import logger class FixHandler(FuzzHandler): @@ -18,7 +18,7 @@ class FixHandler(FuzzHandler): ) except Exception as e: logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name} " + f"[msprobe] Free Benchmark: For {self.params.api_name} " f"Fix output failed. " ) return data_params.original_result \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py similarity index 59% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py index cff629854d9b2bd1413b273171ad4ce73493bbb0..5ee968c6a86728786f526660594fbb6de4ce18ee 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py @@ -1,10 +1,10 @@ -from atat.pytorch.free_benchmark import FreeBenchmarkException -from atat.pytorch.free_benchmark.common.constant import PreheatConfig -from atat.pytorch.free_benchmark.common.enums import HandlerType -from atat.pytorch.free_benchmark.common.params import HandlerParams -from atat.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler -from atat.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler -from atat.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler +from msprobe.pytorch.free_benchmark import FreeBenchmarkException +from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig +from msprobe.pytorch.free_benchmark.common.enums import HandlerType +from msprobe.pytorch.free_benchmark.common.params import HandlerParams +from msprobe.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler +from msprobe.pytorch.free_benchmark.result_handlers.preheat_handler import PreheatHandler +from msprobe.pytorch.free_benchmark.result_handlers.fix_handler import FixHandler class FuzzHandlerFactory: diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py similarity index 88% rename from debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py rename to debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py index 033a6d4931f718fdd2b60b3ad502766ea3f06382..d78e4303620f3ca73522cc9188452ffce0de2b12 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py @@ -1,14 +1,14 @@ import math from typing import Any -from atat.pytorch.free_benchmark import logger -from atat.pytorch.free_benchmark.common.constant import ThresholdConfig -from atat.pytorch.free_benchmark.common.counter import preheat_counter -from atat.pytorch.free_benchmark.common.enums import DeviceType -from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams -from atat.pytorch.free_benchmark.common.utils import Tools -from atat.pytorch.free_benchmark.compare.single_benchmark import SingleCompare -from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler +from msprobe.pytorch.free_benchmark import logger +from msprobe.pytorch.free_benchmark.common.constant import ThresholdConfig +from msprobe.pytorch.free_benchmark.common.counter import preheat_counter +from msprobe.pytorch.free_benchmark.common.enums import DeviceType +from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams +from msprobe.pytorch.free_benchmark.common.utils import Tools +from msprobe.pytorch.free_benchmark.compare.single_benchmark import SingleCompare +from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler class PreheatHandler(FuzzHandler): @@ -74,14 +74,14 @@ class PreheatHandler(FuzzHandler): cpu_consistent = self.compare_npu_and_cpu(data_params) except Exception as e: logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name}, " + f"[msprobe] Free Benchmark: For {self.params.api_name}, " f"when campare to cpu exception raise {e}" ) try: first_dtype = Tools.get_first_tensor_dtype(data_params.original_result) except RuntimeError: logger.warning_on_rank_0( - f"[atat] Free Benchmark: For {self.params.api_name}, " + f"[msprobe] Free Benchmark: For {self.params.api_name}, " f"the output sequence does not contain tensors." ) if preheat_counter.get_api_preheat(self.pure_name, str(first_dtype)): @@ -96,7 +96,7 @@ class PreheatHandler(FuzzHandler): if res: total_count = preheat_counter.get_one_step_used_api(self.pure_name) logger.info_on_rank_0( - f"[atat] Free benchmark: preheat sample in step{self.params.step}" + f"[msprobe] Free benchmark: preheat sample in step{self.params.step}" f"api_name {self.params.api_name}, " f"curr_called_seq: {curr_called_seq}/{total_count}" ) diff --git a/debug/accuracy_tools/msprobe/pytorch/function_factory.py b/debug/accuracy_tools/msprobe/pytorch/function_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..c2fd8bfd0cbbad21e688c44c6fbb8671ba942511 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/function_factory.py @@ -0,0 +1,75 @@ +from msprobe.pytorch.common.utils import logger +from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w +from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \ + npu_confusion_transpose_backward +from msprobe.pytorch.bench_functions.fast_gelu import fast_gelu, npu_fast_gelu_backward +from msprobe.pytorch.bench_functions.layer_norm_eval import npu_layer_norm_eval +from msprobe.pytorch.bench_functions.linear import npu_linear, npu_linear_backward +from msprobe.pytorch.bench_functions.matmul_backward import matmul_backward +from msprobe.pytorch.bench_functions.npu_fusion_attention import npu_fusion_attention, npu_fusion_attention_grad +from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_backward +from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward +from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \ + npu_scaled_masked_softmax_backward +from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward, swish_grad, swish + + +class Register(dict): + def __init__(self, *args, **kwargs): + super(Register, self).__init__(*args, **kwargs) + self._dict = {} + + def __call__(self, target_func_list): + for target in target_func_list: + self.register(target) + return + + def __setitem__(self, key, value): + self._dict[key] = value + + def __getitem__(self, key): + return self._dict[key] + + def __contains__(self, key): + return key in self._dict + + def __str__(self): + return str(self._dict) + + def keys(self): + return self._dict.keys() + + def values(self): + return self._dict.values() + + def items(self): + return self._dict.items() + + def register(self, target): + + def add_register_item(key, value): + if key in self._dict: + logger.warning(f"{value.__name__} has been registered before, so we will overriden it.") + self[key] = value + return value + + if callable(target): + return add_register_item(target.__name__, target) + else: + raise Exception(f"The func {target} is not callable.") + + +# register for npu custom bench functions +npu_custom_functions = Register() +npu_custom_functions([ + npu_apply_adam_w, npu_confusion_transpose, fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention, + npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu +]) + +# register for npu custom backward bench functions +npu_custom_grad_functions = Register() +npu_custom_grad_functions([ + npu_confusion_transpose_backward, npu_fast_gelu_backward, npu_linear_backward, matmul_backward, + npu_fusion_attention_grad, npu_rms_norm_backward, npu_rotary_mul_backward, npu_scaled_masked_softmax_backward, + npu_swiglu_backward +]) diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/__init__.py b/debug/accuracy_tools/msprobe/pytorch/functional/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/functional/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py b/debug/accuracy_tools/msprobe/pytorch/functional/data_processor.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/functional/data_processor.py rename to debug/accuracy_tools/msprobe/pytorch/functional/data_processor.py diff --git a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py b/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py similarity index 73% rename from debug/accuracy_tools/atat/pytorch/functional/dump_module.py rename to debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py index 675fa2a1bfdfdef9b12bf99f2428e3236c86a906..efb95c3369f6cda2f883d70a86261e0232535f86 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py +++ b/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py @@ -1,10 +1,10 @@ import torch.nn as nn -from atat.pytorch.common.log import logger -from atat.core.common.const import Const -from atat.pytorch.hook_module.api_registry import api_register -from atat.pytorch.debugger.precision_debugger import PrecisionDebugger -from atat.core.common.exceptions import MsaccException -from atat.core.data_dump.scope import BaseScope +from msprobe.pytorch.common.log import logger +from msprobe.core.common.const import Const +from msprobe.pytorch.hook_module.api_registry import api_register +from msprobe.pytorch.debugger.precision_debugger import PrecisionDebugger +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.data_dump.scope import BaseScope module_count = {} @@ -12,10 +12,10 @@ module_count = {} def module_dump(module, dump_name): if not isinstance(module, nn.Module): logger.error("The parameter:module in module_dump is not a Module subclass.") - raise MsaccException(MsaccException.INVALID_PARAM_ERROR) + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) if not isinstance(dump_name, str): logger.error("The parameter:dump_name in module_dump is not a str type.") - raise MsaccException(MsaccException.INVALID_PARAM_ERROR) + raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) api_register.api_originality() if dump_name not in module_count: module_count[dump_name] = 0 diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/__init__.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/hook_module/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py similarity index 91% rename from debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py index 3b971cc71ecaf1d229337f4acf5afab6b5b0f9db..f75201eafcda40c61b2c5c3da710d6cfb06719b8 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_registry.py @@ -18,15 +18,15 @@ import torch import torch.distributed as dist -from atat.pytorch.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten -from atat.pytorch.hook_module.wrap_aten import get_aten_ops -from atat.pytorch.hook_module.wrap_distributed import get_distributed_ops -from atat.pytorch.hook_module.wrap_functional import get_functional_ops -from atat.pytorch.hook_module.wrap_tensor import get_tensor_ops -from atat.pytorch.hook_module.wrap_torch import get_torch_ops -from atat.pytorch.hook_module.wrap_vf import get_vf_ops -from atat.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu -from atat.core.common.const import Const +from msprobe.pytorch.hook_module import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten +from msprobe.pytorch.hook_module.wrap_aten import get_aten_ops +from msprobe.pytorch.hook_module.wrap_distributed import get_distributed_ops +from msprobe.pytorch.hook_module.wrap_functional import get_functional_ops +from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops +from msprobe.pytorch.hook_module.wrap_torch import get_torch_ops +from msprobe.pytorch.hook_module.wrap_vf import get_vf_ops +from msprobe.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu +from msprobe.core.common.const import Const torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py similarity index 97% rename from debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py index 57212b6e45c572db3d3a235035647bbb614a3cb3..ff6427e51e5c6bc6b715991979890f759ab955cf 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py @@ -17,10 +17,12 @@ import functools import threading + import torch import torch.nn as nn import torch.utils.hooks as full_hooks -from atat.core.common.const import Const + +from msprobe.core.common.const import Const class HOOKModule(nn.Module): @@ -61,6 +63,10 @@ class HOOKModule(nn.Module): HOOKModule.inner_stop_hook[self.current_thread] = False return result + @classmethod + def reset_module_stats(cls): + cls.module_count = {} + def _call_func(self, *input, **kwargs): full_backward_hooks, non_full_backward_hooks = [], [] if len(self._backward_hooks) > 0: diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml similarity index 99% rename from debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml rename to debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml index d64c577ff38b7b7e3478d59eb1754845973c7103..f68708e945ea5351c84be982250b1dde436f3ba1 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml @@ -1873,4 +1873,5 @@ distributed: - reduce_scatter - _reduce_scatter_base - _all_gather_base - - all_to_all_single \ No newline at end of file + - all_to_all_single + - all_to_all \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/utils.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py similarity index 95% rename from debug/accuracy_tools/atat/pytorch/hook_module/utils.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py index e4ed157af6dcafc826eb74fcc40898bfdc835eac..d991445db091f74fbe9f5f7d5f68dd4bdcf4b868 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py @@ -18,7 +18,7 @@ import os import yaml -from atat.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py similarity index 71% rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py index c5a3c6365d1841927c3acd2993ac2828092ae80b..a99d669beb227c0576d95991d5e89811d229c4c3 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py @@ -20,16 +20,18 @@ import torch import yaml -from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard -from atat.core.common.const import Const -from atat.core.common.file_check import FileOpen - +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.common.utils import torch_device_guard +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import FileOpen +from msprobe.pytorch.function_factory import npu_custom_grad_functions cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") with FileOpen(yaml_path, 'r') as f: - WrapAtenOps = yaml.safe_load(f).get('aten') + Ops = yaml.safe_load(f) + WrapAtenOps = Ops.get('aten') + WhiteAtenOps = Ops.get('white_aten_ops', []) aten_func = {} @@ -48,7 +50,7 @@ class HOOKAtenOP(object): class AtenOPTemplate(HOOKModule): - def __init__(self, op, hook): + def __init__(self, op, hook, need_hook=True): if isinstance(op, torch._ops.OpOverloadPacket): op_name_ = op._qualified_op_name.split("::")[-1] else: @@ -58,10 +60,21 @@ class AtenOPTemplate(HOOKModule): op_name_ = op_name_ + '.' + overload_name self.op = op self.prefix_op_name_ = "Aten" + Const.SEP + str(op_name_) + Const.SEP - super().__init__(hook) + self.need_hook = need_hook + if self.need_hook: + super().__init__(hook) @torch_device_guard def forward(self, *args, **kwargs): + if isinstance(self.op, str): + if self.op in npu_custom_grad_functions: + return npu_custom_grad_functions[self.op](*args, **kwargs) + if self.op in WhiteAtenOps: + return eval(f"torch.ops.aten.{self.op}")(*args, **kwargs) + if self.op not in aten_func: + raise Exception(f"Skip op[{self.op}] accuracy check, because the op is not " + f"in dir(torch.ops.aten) and support yaml.") + return aten_func[self.op](*args, **kwargs) return self.op(*args, **kwargs) @@ -80,13 +93,13 @@ class AtenOPPacketTemplate(): else: return attr - def overloads(self): - return self.opPacket.overloads() - @torch_device_guard def __call__(self, *args, **kwargs): return AtenOPTemplate(self.opPacket, self.hook)(*args, **kwargs) + def overloads(self): + return self.opPacket.overloads() + def wrap_aten_op(op, hook): return AtenOPPacketTemplate(op, hook) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py similarity index 91% rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py index e02189ac1bf2a5d2e1607f2625200eb4ee2ea6e8..54afecb9a663db8aa2db998247fca707a8aa2ced 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py @@ -20,10 +20,10 @@ from functools import wraps import torch.distributed as dist import yaml -from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard -from atat.core.common.const import Const -from atat.core.common.file_check import FileOpen +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.common.utils import torch_device_guard +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py similarity index 94% rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py index fa97f5ee3106c62ab63d80e2b2ebe349494ce695..96d6986a0784cf98e42fd04d45a52f0be6bd39ed 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py @@ -20,11 +20,11 @@ import os import torch import yaml -from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard -from atat.core.common.const import Const -from atat.pytorch.common.log import logger -from atat.core.common.file_check import FileOpen +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.common.utils import torch_device_guard +from msprobe.core.common.const import Const +from msprobe.pytorch.common.log import logger +from msprobe.core.common.file_utils import FileOpen def remove_dropout(): diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py similarity index 71% rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py index 7d0882804f478d8657faeb5f90b0b718914d72e5..607b4ed3cd54a7e745fa1eca715b1ef8ba1e04b2 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py @@ -17,19 +17,26 @@ import os import torch -import torch_npu import yaml -from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, torch_without_guard_version -from atat.core.common.const import Const -from atat.core.common.file_check import FileOpen +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import FileOpen +from msprobe.pytorch.function_factory import npu_custom_functions cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") with FileOpen(yaml_path, 'r') as f: WrapNpuOps = yaml.safe_load(f).get('torch_npu') +try: + import torch_npu +except ImportError: + is_gpu = True +else: + is_gpu = False + def get_npu_ops(): global WrapNpuOps @@ -46,13 +53,19 @@ class HOOKNpuOP(object): class NpuOPTemplate(HOOKModule): - def __init__(self, op_name, hook): + def __init__(self, op_name, hook, need_hook=True): self.op_name_ = op_name self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP - super().__init__(hook) + self.need_hook = need_hook + if need_hook: + super().__init__(hook) @torch_device_guard def forward(self, *args, **kwargs): + if not self.need_hook: + if self.op_name_ not in npu_custom_functions: + raise Exception(f'There is not bench function {self.op_name_}') + return npu_custom_functions[self.op_name_](*args, **kwargs) if torch_without_guard_version: return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs) else: @@ -60,7 +73,6 @@ class NpuOPTemplate(HOOKModule): def wrap_npu_op(op_name, hook): - def npu_op_template(*args, **kwargs): return NpuOPTemplate(op_name, hook)(*args, **kwargs) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py similarity index 89% rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py index 6fac18140238e016004c2329f0d8ff647045c0c0..aba6f86148967f0f5d0af4d9ac2a85150e79ff47 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py @@ -20,10 +20,10 @@ import os import torch import yaml -from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, parameter_adapter -from atat.core.common.const import Const -from atat.core.common.file_check import FileOpen +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py similarity index 91% rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py index f0bd01fe4624ccc1a50d1f3e1fb5d614b30a7a8d..3f9518b7f1a3b9532de0ad016b21246be5bd883d 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py @@ -20,10 +20,10 @@ import os import torch import yaml -from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard -from atat.core.common.const import Const -from atat.core.common.file_check import FileOpen +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.common.utils import torch_device_guard +from msprobe.core.common.const import Const +from msprobe.core.common.file_utils import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py similarity index 88% rename from debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py index d4c570221d4b219be168b25dafd76264769da197..351820fd6c11f71b7465e1e5e456b582161ff462 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py @@ -20,10 +20,10 @@ import os import torch import yaml -from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.core.common.file_check import FileOpen -from atat.pytorch.common.utils import torch_device_guard -from atat.core.common.const import Const +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.core.common.file_utils import FileOpen +from msprobe.pytorch.common.utils import torch_device_guard +from msprobe.core.common.const import Const cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/module_processer.py similarity index 72% rename from debug/accuracy_tools/atat/pytorch/module_processer.py rename to debug/accuracy_tools/msprobe/pytorch/module_processer.py index 8ce9140e32cbc9f7d2d62eede164e24127a40f34..3e9969d32d9147a6f26b7a6a4219364368a6116b 100644 --- a/debug/accuracy_tools/atat/pytorch/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/module_processer.py @@ -1,15 +1,17 @@ from functools import wraps + import torch from torch.utils.hooks import BackwardHook -from atat.core.common.const import Const -from atat.core.data_dump.scope import ModuleRangeScope + +from msprobe.core.common.const import Const +from msprobe.core.data_dump.scope import ModuleRangeScope class ModuleProcesser: + module_count = {} module_stack = [] api_parent_node = "" module_node = {} - current_module_name = "" def __init__(self, scope): if isinstance(scope, ModuleRangeScope): @@ -19,15 +21,22 @@ class ModuleProcesser: BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook) - self.module_count = {} @staticmethod def filter_tensor_and_tuple(func): @wraps(func) def wrap_by_filter_tensor_and_tuple(*args, **kwargs): - # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是非tensor数据不传入 + # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是解析非tensor数据的属性,对tensor属性挂hook # setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1] if not isinstance(args[1], (torch.Tensor, tuple)): + for item_str in dir(args[1]): + item = getattr(args[1], item_str) + # 处理tensor或者只包含tensor的元组 + if isinstance(item, torch.Tensor) or \ + (isinstance(item, tuple) and all(isinstance(x, torch.Tensor) for x in item)): + args_new = (args[0], item) + result = func(*args_new, **kwargs) + setattr(args[1], item_str, result) return args[1] return func(*args, **kwargs) @@ -55,11 +64,26 @@ class ModuleProcesser: else: return result + @staticmethod + def module_count_func(module_name): + if module_name not in ModuleProcesser.module_count: + ModuleProcesser.module_count[module_name] = 0 + else: + ModuleProcesser.module_count[module_name] += 1 + return ModuleProcesser.module_count[module_name] + + @classmethod + def reset_module_stats(cls): + cls.module_count = {} + cls.module_stack = [] + cls.api_parent_node = "" + cls.module_node = {} + def node_hook(self, name_prefix, start_or_stop, **kwargs): def pre_hook(module, input, output=None): try: - index = self.module_count_func(name_prefix) + index = ModuleProcesser.module_count_func(name_prefix) except IndexError as e: index = None pass @@ -89,10 +113,3 @@ class ModuleProcesser: return pre_hook else: return end_hook - - def module_count_func(self, module_name): - if module_name not in self.module_count: - self.module_count[module_name] = 0 - else: - self.module_count[module_name] += 1 - return self.module_count[module_name] diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/online_dispatch/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py similarity index 96% rename from debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py index e6d55ca0614767338023673b6d0cfa5c7d990c0f..4e3d574cd84118dda2e32667be09d8695312a584 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py @@ -6,12 +6,11 @@ import json from collections import namedtuple from rich.table import Table from rich.console import Console -from .single_compare import single_benchmark_compare_wrap -from .utils import DispatchException -from atat.core.common.const import CompareConst -from atat.core.common.file_check import FileOpen -from atat.pytorch.common.log import logger -from atat.core.common.utils import CompareException +from msprobe.pytorch.online_dispatch.single_compare import single_benchmark_compare_wrap +from msprobe.core.common.const import CompareConst +from msprobe.core.common.file_utils import FileOpen +from msprobe.pytorch.common.log import logger +from msprobe.core.common.utils import CompareException ELEMENT_NUM_THRESHOLD = 100 ZERO_NUM_THRESHOLD = 0.1 @@ -228,7 +227,7 @@ class Comparator: else: is_bwd_success, bwd_compare_alg_results = True, None if is_bwd_success and bwd_compare_alg_results is None: - self.saver.record_results(ResultInfo(api_name, is_fwd_success, CompareConst.NA, fwd_compare_alg_results, + self.saver.record_results(ResultInfo(api_name, is_fwd_success, CompareConst.NAN, fwd_compare_alg_results, bwd_compare_alg_results)) else: self.saver.record_results(ResultInfo(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results, diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py similarity index 94% rename from debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py index 7502d746acf38a52b427659b8d97f5f55a4ce96c..2251fa6cb6aabb0054598cbeb0354d3ac9552c7a 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py @@ -16,14 +16,14 @@ except ImportError: else: is_npu = True -from .dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \ +from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \ DispatchRunParam, DisPatchDataInfo -from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \ +from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \ DispatchException -from .compare import Comparator -from atat.core.common.file_check import FileOpen -from atat.core.common.utils import check_file_or_directory_path, check_path_before_create -from atat.core.common.const import Const, CompareConst +from msprobe.pytorch.online_dispatch.compare import Comparator +from msprobe.core.common.file_utils import FileOpen +from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create +from msprobe.core.common.const import Const, CompareConst current_time = time.strftime("%Y%m%d%H%M%S") RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" @@ -209,9 +209,9 @@ class PtdbgDispatch(TorchDispatchMode): time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) if tag is None or not isinstance(tag, str): logger_warn('There is not tag or the type of tag is not string.') - dir_name = f'atat_rank{self.device_id}_{time_now}' + dir_name = f'msprobe_rank{self.device_id}_{time_now}' else: - dir_name = f'atat_{tag}_rank{self.device_id}_{time_now}' + dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}' return dir_name def load_yaml_file(self, file_path): diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py similarity index 95% rename from debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py index cd7c5a3f282d17b4e064a4b00928861b2e32f737..5e8bf4f1117b0a86f67902df723ca08ad19db73e 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py @@ -5,11 +5,11 @@ from datetime import datetime, timezone import pandas as pd import torch -from .utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \ +from msprobe.pytorch.online_dispatch.utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \ COLOR_RESET, CSV_COLUMN_NAME -from atat.core.common.file_check import FileOpen, change_mode -from atat.core.common.const import CompareConst, FileCheckConst, Const -from atat.pytorch.common.log import logger +from msprobe.core.common.file_utils import FileOpen, change_mode +from msprobe.core.common.const import CompareConst, FileCheckConst, Const +from msprobe.pytorch.common.log import logger class DispatchRunParam: def __init__(self, debug_flag, device_id, root_npu_path, root_cpu_path, process_num, comparator): diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/single_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/single_compare.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/online_dispatch/single_compare.py rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/single_compare.py diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/torch_ops_config.yaml b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/torch_ops_config.yaml similarity index 100% rename from debug/accuracy_tools/atat/pytorch/online_dispatch/torch_ops_config.yaml rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/torch_ops_config.yaml diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py similarity index 97% rename from debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py rename to debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py index f3fcffb6f26adbd2b2ff6b77da720f7ecdfcb0ca..c1d1e841a40137508c1fd09d617d57c83e9306a9 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py @@ -12,8 +12,8 @@ except ImportError: else: pta_cpu_device = torch.device("cpu") -from atat.core.common.const import CompareConst, FileCheckConst -from atat.core.common.file_check import change_mode +from msprobe.core.common.const import CompareConst, FileCheckConst +from msprobe.core.common.file_utils import change_mode cpu_device = torch._C.device("cpu") COLOR_RED = '\033[31m' diff --git a/debug/accuracy_tools/atat/pytorch/parse.py b/debug/accuracy_tools/msprobe/pytorch/parse.py similarity index 50% rename from debug/accuracy_tools/atat/pytorch/parse.py rename to debug/accuracy_tools/msprobe/pytorch/parse.py index 40792d0e0297a9b034f186e255193d6201517764..efd3d4a2ddb807e23ba346b6c80ef344d560050d 100644 --- a/debug/accuracy_tools/atat/pytorch/parse.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse.py @@ -1,4 +1,4 @@ -from atat.pytorch.parse_tool import cli +from msprobe.pytorch.parse_tool import cli if __name__ == '__main__': cli.parse() diff --git a/debug/accuracy_tools/kj600/kj600/__init__.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/__init__.py similarity index 100% rename from debug/accuracy_tools/kj600/kj600/__init__.py rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/__init__.py diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/cli.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/cli.py similarity index 89% rename from debug/accuracy_tools/atat/pytorch/parse_tool/cli.py rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/cli.py index f59fbf13a8d3e2785611022cd7b5b9a2926ea008..500e8eef6846b1209817083992a5e3630381b7dc 100644 --- a/debug/accuracy_tools/atat/pytorch/parse_tool/cli.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/cli.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -from atat.pytorch.parse_tool.lib.interactive_cli import InteractiveCli -from atat.pytorch.common.log import logger +from msprobe.pytorch.parse_tool.lib.interactive_cli import InteractiveCli +from msprobe.pytorch.common.log import logger def _run_interactive_cli(cli=None): diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/__init__.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/compare.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py similarity index 92% rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/compare.py rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py index dfc4529414cbe00307b36fc58e5d62d64c6fdf32..2b091c59e8caef33576a7f7c79efcdbc0987afbe 100644 --- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/compare.py @@ -19,9 +19,9 @@ import os import time import numpy as np from collections import namedtuple -from atat.pytorch.parse_tool.lib.utils import Util -from atat.pytorch.parse_tool.lib.config import Const -from atat.pytorch.parse_tool.lib.parse_exception import ParseException +from msprobe.pytorch.parse_tool.lib.utils import Util +from msprobe.pytorch.parse_tool.lib.config import Const +from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException class Compare: @@ -83,16 +83,17 @@ class Compare: (left, right, save_txt, rl, al, diff_count) = args if left is None or right is None: raise ParseException("invalid input or output") - try: - left_data = np.load(left) - right_data = np.load(right) - except UnicodeError as e: - self.log.error("%s %s" % ("UnicodeError", str(e))) - self.log.warning("Please check the npy file") - raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e - except IOError: - self.log.error("Failed to load npy %s or %s." % (left, right)) - raise ParseException(ParseException.PARSE_LOAD_NPY_ERROR) from e + if self.util.check_path_valid(left) and self.util.check_path_valid(right): + try: + left_data = np.load(left) + right_data = np.load(right) + except UnicodeError as e: + self.log.error("%s %s" % ("UnicodeError", str(e))) + self.log.warning("Please check the npy file") + raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e + except IOError: + self.log.error("Failed to load npy %s or %s." % (left, right)) + raise ParseException(ParseException.PARSE_LOAD_NPY_ERROR) from e # save to txt if save_txt: @@ -157,8 +158,10 @@ class Compare: return res def compare_npy(self, file, bench_file, output_path): - data = np.load(file) - bench_data = np.load(bench_file) + if self.util.check_path_valid(file): + data = np.load(file) + if self.util.check_path_valid(bench_file): + bench_data = np.load(bench_file) shape, dtype = data.shape, data.dtype bench_shape, bench_dtype = bench_data.shape, bench_data.dtype filename = os.path.basename(file) diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/config.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py similarity index 98% rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/config.py rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py index a745ff46f08a28c39c989a5d8dce4ff5cf475ee5..a9a8b2b00e2703ae28d6beff114848de099f3900 100644 --- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/config.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/config.py @@ -33,7 +33,7 @@ class Const: OFFLINE_DUMP_CONVERT_PATTERN = \ r"^([A-Za-z0-9_-]+)\.([A-Za-z0-9_-]+)\.([0-9]+)(\.[0-9]+)?\.([0-9]{1,255})" \ r"\.([a-z]+)\.([0-9]{1,255})(\.[x0-9]+)?\.npy$" - NUMPY_PATTERN = r".*\.npy$" + NUMPY_PATTERN = r"^[\w\-_-]\.npy$" NPY_SUFFIX = ".npy" PKL_SUFFIX = ".pkl" DIRECTORY_LENGTH = 4096 diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/file_desc.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/file_desc.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/file_desc.py rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/file_desc.py diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/interactive_cli.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py similarity index 93% rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/interactive_cli.py rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py index 12b07183fbc2e1c2ea630f05ac44deda744d4d01..1ea7dd30153e458b758dc0a79779b54a25fe8289 100644 --- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/interactive_cli.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/interactive_cli.py @@ -16,10 +16,10 @@ """ import cmd import argparse -from atat.pytorch.parse_tool.lib.parse_tool import ParseTool -from atat.pytorch.parse_tool.lib.utils import Util -from atat.pytorch.parse_tool.lib.config import Const -from atat.pytorch.parse_tool.lib.parse_exception import catch_exception +from msprobe.pytorch.parse_tool.lib.parse_tool import ParseTool +from msprobe.pytorch.parse_tool.lib.utils import Util +from msprobe.pytorch.parse_tool.lib.config import Const +from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception class InteractiveCli(cmd.Cmd): diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_exception.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_exception.py similarity index 96% rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_exception.py rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_exception.py index 1177c51985dc82fba632898590762e38387603ab..7525230cedc7ff11d4112a55998c6414e8f09217 100644 --- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_exception.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_exception.py @@ -15,7 +15,7 @@ # limitations under the License. """ import logging -from atat.core.common.exceptions import FileCheckException +from msprobe.core.common.exceptions import FileCheckException class ParseException(Exception): diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_tool.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py similarity index 95% rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_tool.py rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py index 3e02baa1272199a961fd550a0837d68100de8348..9a47dc54cf9e15d65eb5fb8d3bae54358777051c 100644 --- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/parse_tool.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/parse_tool.py @@ -18,11 +18,11 @@ import argparse import os from collections import namedtuple -from atat.pytorch.parse_tool.lib.config import Const -from atat.pytorch.parse_tool.lib.utils import Util -from atat.pytorch.parse_tool.lib.compare import Compare -from atat.pytorch.parse_tool.lib.visualization import Visualization -from atat.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException +from msprobe.pytorch.parse_tool.lib.config import Const +from msprobe.pytorch.parse_tool.lib.utils import Util +from msprobe.pytorch.parse_tool.lib.compare import Compare +from msprobe.pytorch.parse_tool.lib.visualization import Visualization +from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException class ParseTool: diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py similarity index 94% rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py index ce42d242ba29ef4dcf7f2af042242b1b9987de42..a8abec2d15634d08429e51899a83d4938c4a8869 100644 --- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py @@ -25,15 +25,15 @@ import csv import time import numpy as np from collections import namedtuple -from atat.pytorch.parse_tool.lib.config import Const -from atat.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc -from atat.pytorch.parse_tool.lib.parse_exception import ParseException -from atat.core.common.file_check import change_mode, check_other_user_writable,\ +from msprobe.pytorch.parse_tool.lib.config import Const +from msprobe.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc +from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException +from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\ check_path_executable, check_path_owner_consistent -from atat.core.common.const import FileCheckConst -from atat.core.common.file_check import FileOpen -from atat.core.common.utils import check_file_or_directory_path -from atat.pytorch.common.log import logger +from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.file_utils import FileOpen +from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create +from msprobe.pytorch.common.log import logger try: @@ -73,16 +73,6 @@ class Util: def path_strip(path): return path.strip("'").strip('"') - @staticmethod - def _gen_npu_dump_convert_file_info(name, match, dir_path): - return DumpDecodeFileDesc(name, dir_path, int(match.groups()[-4]), op_name=match.group(2), - op_type=match.group(1), task_id=int(match.group(3)), anchor_type=match.groups()[-3], - anchor_idx=int(match.groups()[-2])) - - @staticmethod - def _gen_numpy_file_info(name, math, dir_path): - return FileDesc(name, dir_path) - @staticmethod def check_executable_file(path): check_path_owner_consistent(path) @@ -184,6 +174,16 @@ class Util: def change_filemode_safe(self, path): change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY) + @staticmethod + def _gen_npu_dump_convert_file_info(name, match, dir_path): + return DumpDecodeFileDesc(name, dir_path, int(match.groups()[-4]), op_name=match.group(2), + op_type=match.group(1), task_id=int(match.group(3)), anchor_type=match.groups()[-3], + anchor_idx=int(match.groups()[-2])) + + @staticmethod + def _gen_numpy_file_info(name, math, dir_path): + return FileDesc(name, dir_path) + def execute_command(self, cmd): if not cmd: self.log.error("Commond is None") @@ -245,7 +245,11 @@ class Util: elif data.size % align != 0: pad_array = np.zeros((align - data.size % align,)) data = np.append(data, pad_array) - np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g') + check_path_before_create(dst_file) + try: + np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g') + except Exception as e: + self.log.error("An unexpected error occurred: %s when savetxt to %s" % (str(e)), dst_file) change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY) def list_convert_files(self, path, external_pattern=""): diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/visualization.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py similarity index 94% rename from debug/accuracy_tools/atat/pytorch/parse_tool/lib/visualization.py rename to debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py index 3ef9878ae8213e5ece5813d8857544ce07b603d5..a10c7a447fd1cc5b18202bfb382190ed7a663ccc 100644 --- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/visualization.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py @@ -17,10 +17,10 @@ import json import numpy as np -from atat.pytorch.parse_tool.lib.config import Const -from atat.pytorch.parse_tool.lib.utils import Util -from atat.pytorch.parse_tool.lib.parse_exception import ParseException -from atat.core.common.file_check import FileOpen +from msprobe.pytorch.parse_tool.lib.config import Const +from msprobe.pytorch.parse_tool.lib.utils import Util +from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException +from msprobe.core.common.file_utils import FileOpen class Visualization: diff --git a/debug/accuracy_tools/atat/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py similarity index 58% rename from debug/accuracy_tools/atat/pytorch/pt_config.py rename to debug/accuracy_tools/msprobe/pytorch/pt_config.py index 0674b91b3410765ca7dcb9a6f38c6f03fa6e94f1..92145ee2c70fe8667726977236afecb4c14c4d3b 100644 --- a/debug/accuracy_tools/atat/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -1,9 +1,10 @@ import json import os -from atat.core.common_config import CommonConfig, BaseConfig -from atat.core.common.file_check import FileOpen -from atat.core.common.const import Const +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.core.common.file_utils import FileOpen +from msprobe.core.common.const import Const +from msprobe.pytorch.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps class TensorConfig(BaseConfig): @@ -31,12 +32,12 @@ class StatisticsConfig(BaseConfig): class OverflowCheckConfig(BaseConfig): def __init__(self, json_config): super().__init__(json_config) - self.overflow_num = json_config.get("overflow_nums") + self.overflow_nums = json_config.get("overflow_nums") self.check_mode = json_config.get("check_mode") self.check_overflow_config() def check_overflow_config(self): - if self.overflow_num is not None and not isinstance(self.overflow_num, int): + if self.overflow_nums is not None and not isinstance(self.overflow_nums, int): raise Exception("overflow_num is invalid") if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]: raise Exception("check_mode is invalid") @@ -61,20 +62,54 @@ class FreeBenchmarkCheckConfig(BaseConfig): if self.preheat_step and self.preheat_step == 0: raise Exception("preheat_step cannot be 0") + +class RunUTConfig(BaseConfig): + WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps) + def __init__(self, json_config): + super().__init__(json_config) + self.white_list = json_config.get("white_list", Const.DEFAULT_LIST) + self.black_list = json_config.get("black_list", Const.DEFAULT_LIST) + self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH) + self.check_run_ut_config() + + @classmethod + def check_filter_list_config(cls, key, filter_list): + if not isinstance(filter_list, list): + raise Exception("%s must be a list type" % key) + if not all(isinstance(item, str) for item in filter_list): + raise Exception("All elements in %s must be string type" % key) + invalid_api = [item for item in filter_list if item not in cls.WrapApi] + if invalid_api: + raise Exception("Invalid api in %s: %s" % (key, invalid_api)) + + @classmethod + def check_error_data_path_config(cls, error_data_path): + if not os.path.exists(error_data_path): + raise Exception("error_data_path: %s does not exist" % error_data_path) + + def check_run_ut_config(self): + RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list) + RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list) + RunUTConfig.check_error_data_path_config(self.error_data_path) + + def parse_task_config(task, json_config): default_dic = {} if task == Const.TENSOR: - config_dic = json_config.get(Const.TENSOR) if json_config.get(Const.TENSOR) else default_dic + config_dic = json_config.get(Const.TENSOR, default_dic) return TensorConfig(config_dic) elif task == Const.STATISTICS: - config_dic = json_config.get(Const.STATISTICS) if json_config.get(Const.STATISTICS) else default_dic + config_dic = json_config.get(Const.STATISTICS, default_dic) return StatisticsConfig(config_dic) elif task == Const.OVERFLOW_CHECK: - config_dic = json_config.get(Const.OVERFLOW_CHECK) if json_config.get(Const.OVERFLOW_CHECK) else default_dic + config_dic = json_config.get(Const.OVERFLOW_CHECK, default_dic) return OverflowCheckConfig(config_dic) elif task == Const.FREE_BENCHMARK: - config_dic = json_config.get(Const.FREE_BENCHMARK) if json_config.get(Const.FREE_BENCHMARK) else default_dic + config_dic = json_config.get(Const.FREE_BENCHMARK, default_dic) return FreeBenchmarkCheckConfig(config_dic) + elif task == Const.RUN_UT: + config_dic = json_config.get(Const.RUN_UT, default_dic) + return RunUTConfig(config_dic) else: return StatisticsConfig(default_dic) diff --git a/debug/accuracy_tools/atat/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py similarity index 77% rename from debug/accuracy_tools/atat/pytorch/service.py rename to debug/accuracy_tools/msprobe/pytorch/service.py index d0b9c4d4b27bd13672e0529de6dd3349860e533f..3238d11b2b96689f0185d71d6c8f1c6a04e3f270 100644 --- a/debug/accuracy_tools/atat/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -1,18 +1,23 @@ import functools import os from pathlib import Path - -from atat.pytorch.common.log import logger -from atat.core.common.file_check import FileChecker, check_path_before_create -from atat.core.common.const import Const, FileCheckConst -from atat.core.common.exceptions import DistributedNotInitializedError, MsaccException -from atat.core.data_dump.data_collector import build_data_collector -from atat.core.data_dump.scope import BaseScope -from atat.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs -from atat.pytorch.common.utils import get_rank_if_initialized -from atat.pytorch.module_processer import ModuleProcesser -from atat.pytorch.hook_module import remove_dropout -from atat.pytorch.hook_module.api_registry import api_register +import torch +from packaging import version +from msprobe.core.common.const import Const, FileCheckConst +from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException +from msprobe.core.common.file_utils import FileChecker, check_path_before_create +from msprobe.core.data_dump.data_collector import build_data_collector +from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs +from msprobe.core.data_dump.scope import BaseScope +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import get_rank_if_initialized +from msprobe.pytorch.hook_module import remove_dropout +from msprobe.pytorch.hook_module.api_registry import api_register +from msprobe.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.module_processer import ModuleProcesser + +if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook class Service: @@ -27,6 +32,11 @@ class Service: self.current_rank = None self.dump_iter_dir = None + @staticmethod + def forward_backward_dump_end(): + logger.info_on_rank_0("Data needed ends here.") + api_register.api_originality() + def build_hook(self, module_type, name): def pre_hook(api_or_module_name, module, args, kwargs): if module_type == BaseScope.Module_Type_Module: @@ -62,7 +72,8 @@ class Service: if not self.switch: return if self.data_collector: - module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output) + # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序 + module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output) pid = os.getpid() @@ -73,15 +84,30 @@ class Service: backward_hook = functools.partial(backward_hook, backward_name_template) return pre_forward_hook, forward_hook, backward_hook + def hook_optimizer(self, model): + def optimizer_pre_step_hook(optimizer, args, kwargs): + self.stop() + self.step() + + def optimizer_post_step_hook(optimizer, args, kwargs): + self.start(model) + + + register_optimizer_step_pre_hook(optimizer_pre_step_hook) + register_optimizer_step_post_hook(optimizer_post_step_hook) + def step(self): self.current_iter += 1 self.data_collector.update_iter(self.current_iter) - def start(self, model): + ModuleProcesser.reset_module_stats() + HOOKModule.reset_module_stats() + + def start(self, model, api_origin=False): self.model = model if self.config.step and self.current_iter > max(self.config.step): self.stop() - raise Exception("atat: exit after iteration {}".format(max(self.config.step))) + raise Exception("msprobe: exit after iteration {}".format(max(self.config.step))) if self.config.step and self.current_iter not in self.config.step: return if self.first_start: @@ -94,6 +120,8 @@ class Service: return self.register_hook_new() self.first_start = False + if api_origin: + api_register.api_modularity() self.switch = True logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ") if self.config.level != "L2": @@ -138,7 +166,7 @@ class Service: logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task)) if self.config.level in ["L0", "mix"]: if self.model is None: - logger.error_log_with_exp("The model is None.", MsaccException.INVALID_PARAM_ERROR) + logger.error_log_with_exp("The model is None.", MsprobeException.INVALID_PARAM_ERROR) logger.info_on_rank_0("The init dump mode is enabled, and the module dump function will not be available") for name, module in self.model.named_modules(): if module == self.model: @@ -164,4 +192,4 @@ class Service: api_register.api_modularity() if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task: - remove_dropout() + remove_dropout() \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/torchair_compare/msit_path_add.py b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/msit_path_add.py new file mode 100644 index 0000000000000000000000000000000000000000..fe49958cd8752bd45d6c5f9c198a49ea1c0996ba --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/msit_path_add.py @@ -0,0 +1,35 @@ +import os +import site +from msprobe.core.common.file_utils import change_mode +from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.log import logger +from msprobe.core.common.utils import PathAddException + + +def create_symlink(): + create_single_symlink(link_rel_path="msit_llm", source_rel_path="msprobe/msit/msit/components/llm/msit_llm") + create_single_symlink(link_rel_path="components", source_rel_path="msprobe/msit/msit/components") + logger.info("torchair compare path added successfully.") + + +def create_single_symlink(link_rel_path, source_rel_path): + site_package_dir = site.getsitepackages()[0] + link_path = os.path.join(site_package_dir, link_rel_path) + if os.path.exists(link_path) or os.path.islink(link_path): + logger.warning("The msit and msit_llm packages are already installed. If you need accuracy comparison " + "capabilities in the torchair training scenario, please ensure that versions of the packages " + "support this feature or update to the latest version.") + raise PathAddException(PathAddException.PACKAGE_VERSION_CHECK) + try: + source_path = os.path.join(site_package_dir, source_rel_path) + os.symlink(source_path, link_path) + change_mode(link_path, FileCheckConst.DATA_DIR_AUTHORITY) + except FileNotFoundError as e: + logger.error("Source folder does not exist before create symlink.") + raise PathAddException(PathAddException.INVALID_FILE_ERROR) from e + except PermissionError as e: + logger.error("Permission denied when trying to create the symlink.") + raise PathAddException(PathAddException.PERMISSION_DENIED_ERROR) from e + except Exception as e: + logger.error(f"An unexpected error occurred: {e}") + raise PathAddException(PathAddException.UNEXPECTED_ERROR) from e diff --git a/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_compare_cli.py b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_compare_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..c46e3f04456db0f63eb7c19ba94a5252f99f26b2 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_compare_cli.py @@ -0,0 +1,14 @@ +from msit_llm.compare.torchair_acc_cmp import acc_compare + + +def torchair_compare_parser(parser): + parser.add_argument("--my-path", dest="my_path", type=str, + help=" The graph compare input path.", required=True) + parser.add_argument("--golden-path", dest="golden_path", type=str, + help=" The graph compare gold input path.", required=True) + parser.add_argument("--output-path", dest="output_path", type=str, default=".", + help=" The graph compare input path.", required=False) + + +def torchair_compare_cli(args): + acc_compare(args.golden_path, args.my_path, args.output_path) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/__init__.py b/debug/accuracy_tools/msprobe/pytorch/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/__init__.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..675a87a412cf078e940441e53512156e8082476e --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from msprobe.pytorch.visualization.graph.graph import Graph, BaseNode +from msprobe.pytorch.visualization.graph.node_op import NodeOp +from msprobe.pytorch.visualization.utils import load_json_file, load_data_json_file, save_json_file, GraphConst +from msprobe.pytorch.visualization.builder.msprobe_adapter import get_input_output + + +class GraphBuilder: + @staticmethod + def build(construct_path, data_path, model_name='DefaultModel'): + """ + GraphBuilder的对外提供的构图方法 + Args: + construct_path: construct.json路径 + data_path: dump.json路径 + model_name: 模型名字,依赖外部输入 + Returns: Graph,代表图的数据结构 + """ + construct_dict = load_json_file(construct_path) + data_dict = load_data_json_file(data_path) + graph = Graph(model_name) + GraphBuilder._init_nodes(graph, construct_dict, data_dict) + GraphBuilder._collect_apis_between_modules(graph) + return graph + + @staticmethod + def to_json(filename, config): + """ + 将graph导出成.vis文件的接口 + """ + result = {} + if config.graph_b: + result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict() + result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict() + else: + result = config.graph_n.to_dict() + if config.tool_tip: + result[GraphConst.JSON_TIP_KEY] = config.tool_tip + if config.node_colors: + result[GraphConst.COLORS] = config.node_colors + if config.micro_steps: + result[GraphConst.MICRO_STEPS] = config.micro_steps + save_json_file(filename, result) + + @staticmethod + def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id): + """ + 如果backward节点的父级节点是null,则尝试从同名的forward节点寻找父级节点 + """ + # 匹配以.backward.后跟一个或多个数字结尾的模式 + backward_pattern = r"(\.backward\.)(\d+)$" + forward_pattern = r"(\.forward\.)(\d+)$" + if re.search(backward_pattern, subnode_id) and not upnode_id: + forward_upnode_id = construct_dict.get(re.sub(backward_pattern, r".forward.\2", subnode_id)) + if forward_upnode_id: + new_upnode_id = re.sub(forward_pattern, r".backward.\2", forward_upnode_id) + if new_upnode_id in construct_dict: + return new_upnode_id + return upnode_id + + @staticmethod + def _init_nodes(graph, construct_dict, data_dict): + for subnode_id, upnode_id in construct_dict.items(): + upnode_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id) + if upnode_id: + upnode_op = NodeOp.get_node_op(upnode_id) + upnode = GraphBuilder._create_or_get_node(graph, data_dict, upnode_op, upnode_id) + else: + upnode = graph.root + node_op = NodeOp.get_node_op(subnode_id) + GraphBuilder._create_or_get_node(graph, data_dict, node_op, subnode_id, upnode) + + @staticmethod + def _create_or_get_node(graph, data_dict, op, name, upnode=None): + if name in graph.node_map: + node = graph.get_node(name) + else: + graph.add_node(op, name, upnode) + node = graph.get_node(name) + node_data = data_dict.get(name, {}) + # 添加输入输出数据 + input_data, output_data = get_input_output(node_data, node.id) + # 更新数据 + node.set_input_output(input_data, output_data) + # 添加节点 + node.add_upnode(upnode) + return node + + @staticmethod + def _collect_apis_between_modules(graph): + """ + 图首次展开,这些首层节点包含许多module和api,api数量很多导致图被拉得很长严重影响查阅,因此将module之间的apis收集起来成为节点 + Args: + graph: 模型结构 + + Returns: None + """ + i = 0 + output = [] + node_list = graph.root.subnodes + while i < len(node_list): + current_node = node_list[i] + + # 当前节点为api,检查后续是否还有api + if current_node.op == NodeOp.function_api: + temp_nodes = [current_node] + i += 1 + while i < len(node_list) and node_list[i].op == NodeOp.function_api: + temp_nodes.append(node_list[i]) + i += 1 + + # 检查api节点是否大于等于2个 + if len(temp_nodes) >= 2: + # 创建新节点,将这些api节点放入新节点的subnodes属性 + node_id = graph.add_node(NodeOp.api_collection, GraphConst.APIS_BETWEEN_MODULES, + id_accumulation=True) + api_collection_node = graph.get_node(node_id) + api_collection_node.subnodes = temp_nodes + output.append(api_collection_node) + else: + # 如果连续的api节点不足2个,将它们原样添加到输出列表 + output.extend(temp_nodes) + else: + # 如果当前节点为module,直接添加到输出列表 + output.append(current_node) + i += 1 + + graph.root.subnodes = output + + +class GraphExportConfig: + def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None): + self.graph_n = graph_n + self.graph_b = graph_b + self.tool_tip = tool_tip + self.node_colors = node_colors + self.micro_steps = micro_steps diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..92af6d67325c2b2ff2b4dbd34bf799b3219445de --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py @@ -0,0 +1,210 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from msprobe.pytorch.compare.acc_compare import read_op, merge_tensor, get_accuracy, _do_multi_process +from msprobe.core.common.utils import task_dumppath_get +from msprobe.pytorch.visualization.utils import GraphConst + + +# 用于将节点名字解析成对应的NodeOp的规则 +op_patterns = [ + r'^(Module)', #NodeOp.module + r'^(Tensor|Torch|Functional|NPU|VF|Distributed|Aten)' #NodeOp.function_api +] + + +def get_compare_mode(dump_path_param): + """ + 获得比较模式,包括summary、MD5和真实数据三种模式 + Args: + dump_path_param: 调用acc_compare接口所依赖的参数 + Returns: 0 summary mode, 1 md5 mode, 2 true data mode + """ + summary_compare, md5_compare = task_dumppath_get(dump_path_param) + if summary_compare: + compare_mode = GraphConst.SUMMARY_COMPARE + elif md5_compare: + compare_mode = GraphConst.MD5_COMPARE + else: + compare_mode = GraphConst.REAL_DATA_COMPARE + return compare_mode + + +def run_real_data(dump_path_param, csv_path): + """ + 多进程运行生成真实数据 + Args: + dump_path_param: 调用acc_compare接口所依赖的参数 + csv_path: 生成文件路径 + """ + return _do_multi_process(dump_path_param, csv_path) + + +def get_input_output(node_data, node_id): + """ + 将dump的原始数据进行拆解,分解为output和input两个数据 + Args: + node_data: 属于单个节点的dump数据 + node_id: 节点名字 + """ + input_data = {} + output_data = {} + op_parsed_list = read_op(node_data, node_id) + for item in op_parsed_list: + full_op_name = item.get('full_op_name', '') + if not full_op_name: + continue + splits = full_op_name.split('.') + if len(splits) < GraphConst.OUTPUT_MIN_LEN: + continue + if GraphConst.OUTPUT in splits[GraphConst.OUTPUT_INDEX_TWO] and \ + GraphConst.INPUT not in splits[GraphConst.OUTPUT_INDEX_THREE]: + output_data[full_op_name] = item + else: + input_data[full_op_name] = item + return input_data, output_data + + +def compare_data(data_dict_list1, data_dict_list2): + """ + 比较get_input_output中输出的结果是否结构一致,比较一致返回True + """ + if len(data_dict_list1) != len(data_dict_list2): + return False + # 用于比较两个节点是否相等的关键字段 + tag_keys = ['type', 'dtype', 'shape'] + for key1, key2 in zip(data_dict_list1, data_dict_list2): + dict1 = data_dict_list1[key1] + dict2 = data_dict_list2[key2] + for tag_key in tag_keys: + tag_value1 = dict1.get(tag_key, None) + tag_value2 = dict2.get(tag_key, None) + if tag_value1 != tag_value2: + return False + return True + + +def compare_mapping_data(data_dict_list1, data_dict_list2): + """ + node1映射node2,可能node1参数多于或少于node2参数,个别参数的shape的维度顺序不同,node1参数null对应node2参数其他值 + 工具要尽可能保证node的数据能够比对,进行数据的弱校验,仅校验参数的shape维度数值是否相同 + """ + for x, y in zip(data_dict_list1.values(), data_dict_list2.values()): + x_shape = x.get('shape') + y_shape = y.get('shape') + if x_shape is None or y_shape is None: + continue + x_shape = sorted(x_shape) if isinstance(x_shape, list) else x_shape + y_shape = sorted(y_shape) if isinstance(y_shape, list) else y_shape + if x_shape != y_shape: + return False + return True + + +def format_node_data(data_dict): + """ + 批量进行节点数据的输出 + """ + del_list = ['requires_grad', 'data_name', 'full_op_name'] + for _, value in data_dict.items(): + if not isinstance(value, dict): + continue + for item in del_list: + if item in value: + del value[item] + _format_data(value) + return data_dict + + +def compare_node(node_ids, data_dicts, stack_json_data, is_summary_compare, is_md5_compare): + """ + 调用acc_compare.py中的get_accuracy获得精度对比指标 + 真实数据对比模式无法获得精度对比指标,需要调用多进程比对接口 + Returns: 包含参数信息和对比指标(真实数据对比模式除外)的list + """ + merge_n = _parse_node(node_ids[0], data_dicts[0], stack_json_data, is_summary_compare, is_md5_compare) + merge_b = _parse_node(node_ids[1], data_dicts[1], stack_json_data, is_summary_compare, is_md5_compare) + result = [] + get_accuracy(result, merge_n, merge_b, is_summary_compare, is_md5_compare) + return result + + +def _parse_node(node_id, data_dict, stack_json_data, is_summary_compare, is_md5_compare): + """ + 转换节点,使其能够作为acc_compare.py中的get_accuracy的入参 + """ + op_parsed_list = read_op(data_dict.get(node_id, {}), node_id) + if node_id in stack_json_data: + op_parsed_list.append( + {'full_op_name': node_id, 'full_info': stack_json_data[node_id]}) + else: + op_parsed_list.append({'full_op_name': node_id, 'full_info': None}) + result = merge_tensor(op_parsed_list, is_summary_compare, is_md5_compare) + if not result: + result['op_name'] = [] + return result + + +def _format_decimal_string(s): + """ + 使用正则表达式匹配包含数字、小数点和可选的百分号的字符串 + """ + pattern = re.compile(r'\d{1,20}\.\d{1,20}%?') + matches = pattern.findall(s) + for match in matches: + is_percent = match.endswith('%') + number_str = match.rstrip('%') + decimal_part = number_str.split('.')[1] + # 如果小数位数大于6,进行处理 + if len(decimal_part) > GraphConst.ROUND_TH: + number_float = float(number_str) + formatted_number = f"{number_float:.{GraphConst.ROUND_TH}f}" + # 如果原来是百分数,加回百分号 + if is_percent: + formatted_number += '%' + # 替换原字符串中的数值部分 + s = s.replace(match, formatted_number) + return s + + +def _format_data(data_dict): + """ + 格式化数据,小数保留6位,处理一些异常值 + """ + pattern = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$' + none_num = 0 + for key, value in data_dict.items(): + if isinstance(value, str): + # 将单引号删掉,None换成null避免前端解析错误 + value = value.replace("'", "").replace(GraphConst.NONE, GraphConst.NULL) + value = _format_decimal_string(value) + elif value is None or value == ' ': + value = GraphConst.NULL + # 科学计数法1.123123123123e-11,格式化为1.123123e-11 + elif isinstance(value, float) and len(str(value)) < GraphConst.STR_MAX_LEN and re.match(pattern, str(value)): + value = "{:.6e}".format(value) + elif isinstance(value, float): + value = round(value, GraphConst.ROUND_TH) + # Inf会走入这里,确保转成Inf。另外给其他不符合预期的类型做兜底方案 + if not isinstance(value, (list, tuple, dict, str)): + value = str(value) + if value == GraphConst.NULL or key == GraphConst.ERROR_KEY: + none_num += 1 + data_dict[key] = value + # 字典里的value全null,只保留一个null + if none_num == len(data_dict): + data_dict.clear() + data_dict[GraphConst.VALUE] = GraphConst.NULL diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/__init__.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ec6c8db4be6431352570cfeb3644a7da816f35 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.pytorch.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data +from msprobe.pytorch.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df +from msprobe.pytorch.visualization.graph.graph import Graph, NodeOp +from msprobe.pytorch.visualization.graph.node_colors import NodeColors +from msprobe.pytorch.visualization.compare.mode_adapter import ModeAdapter + + +class GraphComparator: + def __init__(self, graphs, data_paths, stack_path, output_path, mapping_config=None): + self.graph_n = graphs[0] + self.graph_b = graphs[1] + self._parse_param(data_paths, stack_path, output_path) + self.mapping_config = mapping_config + + def compare(self): + """ + 比较函数,初始化结束后单独调用。比较结果写入graph_n + """ + self._compare_nodes(self.graph_n.root) + self._postcompare() + + def add_compare_result_to_node(self, node, compare_result_list): + """ + 将比对结果添加到节点的输入输出数据中 + Args: + node: 节点 + compare_result_list: 包含参数信息和对比指标(真实数据对比模式除外)的list + """ + # 真实数据比对,先暂存节点,在多进程对比得到精度指标后,再将指标添加到节点中 + if self.ma.prepare_real_data(node): + return + compare_in_dict = {} + compare_out_dict = {} + # input和output对比数据分开 + for item in compare_result_list: + if not node.stack_info and node.id in item[0]: + node.stack_info = item[-1] + if 'output' in item[0]: + compare_out_dict[item[0]] = item + else: + compare_in_dict[item[0]] = item + precision_index, other_dict = ( + self.ma.parse_result(node, [compare_in_dict, compare_out_dict])) + node.data[GraphConst.JSON_INDEX_KEY] = precision_index + node.data.update(other_dict) + if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index): + self.ma.add_error_key(node.output_data) + node.get_suggestions() + + def _parse_param(self, data_paths, stack_path, output_path): + self.dump_path_param = { + 'npu_json_path': data_paths[0], + 'bench_json_path': data_paths[1], + 'stack_json_path': stack_path, + 'is_print_compare_log': True + } + self.output_path = output_path + compare_mode = get_compare_mode(self.dump_path_param) + self.ma = ModeAdapter(compare_mode) + self.data_n_dict = load_data_json_file(data_paths[0]) + self.data_b_dict = load_data_json_file(data_paths[1]) + self.stack_json_data = load_json_file(stack_path) + + def _postcompare(self): + self._handle_api_collection_index() + if not self.ma.is_real_data_compare(): + return + df = get_csv_df(self.ma.is_md5_compare(), self.ma.is_summary_compare(), True, self.ma.csv_data) + df = run_real_data(self.dump_path_param, df) + compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()} + for node in self.ma.compare_nodes: + precision_index, _ = self.ma.parse_result(node, [compare_data_dict]) + node.data[GraphConst.JSON_INDEX_KEY] = precision_index + if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index): + self.ma.add_error_key(node.output_data) + node.get_suggestions() + + def _handle_api_collection_index(self): + """ + api集合的指标使用集合中所有api最小的指标 + """ + for node in self.graph_n.root.subnodes: + if node.op == NodeOp.api_collection: + precision_index = 1 + for api in node.subnodes: + precision_index = min(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, 1)) + node.data[GraphConst.JSON_INDEX_KEY] = precision_index + + def _compare_nodes(self, node_n): + """ + 递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比 + 这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息 + """ + if self.mapping_config: + node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_config) + if node_b: + ancestors_n.append(node_n.id) + ancestors_b.append(node_b.id) + node_n.matched_node_link = ancestors_b + node_b.matched_node_link = ancestors_n + else: + node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b) + if node_b: + ancestors.append(node_b.id) + node_n.add_link(node_b, ancestors) + if node_b: + # 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口 + compare_result_list = compare_node([node_n.id, node_b.id], + [self.data_n_dict, self.data_b_dict], + self.stack_json_data, self.ma.is_summary_compare(), + self.ma.is_md5_compare()) + if compare_result_list: + self.ma.add_csv_data(compare_result_list) + self.add_compare_result_to_node(node_n, compare_result_list) + for subnode in node_n.subnodes: + self._compare_nodes(subnode) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb16c8a789fc98522cfc4e62e8563a8c3a5a8bd --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from msprobe.core.common.const import CompareConst, Const +from msprobe.pytorch.visualization.utils import ToolTip, GraphConst, str2float + + +class ModeAdapter: + def __init__(self, compare_mode): + self.compare_mode = compare_mode + self.csv_data = [] + self.compare_nodes = [] + + @staticmethod + def _add_md5_compare_data(node_data, compare_data_dict): + precision_index = GraphConst.MIN_INDEX_KEY + for key, value in node_data.items(): + if not isinstance(value, dict): + continue + compare_data = compare_data_dict.get(key) + if compare_data: + headers = CompareConst.MD5_COMPARE_RESULT_HEADER + id_list = [headers.index(x) for x in GraphConst.MD5_INDEX_LIST] + ModeAdapter._match_data(value, compare_data, GraphConst.MD5_INDEX_LIST, id_list) + # md5比对是否通过 + if value.get(CompareConst.RESULT) != CompareConst.PASS: + precision_index = GraphConst.MAX_INDEX_KEY + node_data[key] = value + return precision_index + + @staticmethod + def _add_real_compare_data(node_data, compare_data_dict): + min_thousandth = float(1) + numbers = [] + for key, value in node_data.items(): + if not isinstance(value, dict): + continue + compare_data = compare_data_dict.get(key) + if compare_data: + headers = CompareConst.COMPARE_RESULT_HEADER + id_list = [headers.index(x) for x in GraphConst.REAL_DATA_INDEX_LIST] + ModeAdapter._match_data(value, compare_data, GraphConst.REAL_DATA_INDEX_LIST, id_list) + # 获取一个节点所有的输入或输出最小的双千指标 + thousandth = value.get(CompareConst.ONE_THOUSANDTH_ERR_RATIO) + # 可能是None,可能是非数字内容str + try: + thousandth = float(thousandth) + except (ValueError, TypeError): + thousandth = None + if thousandth is not None: + numbers.append(thousandth) + node_data[key] = value + # 双千指标都是None的异常情况 + if not numbers: + min_thousandth = None + else: + min_thousandth = min(numbers + [min_thousandth]) + return min_thousandth + + @staticmethod + def _add_summary_compare_data( node_data, compare_data_dict): + max_relative_err = 0 + for key, value in node_data.items(): + if not isinstance(value, dict): + continue + compare_data = compare_data_dict.get(key) + if compare_data: + # 对应比对结果csv的列 + key_list = GraphConst.SUMMARY_INDEX_LIST + headers = CompareConst.SUMMARY_COMPARE_RESULT_HEADER + id_list = [headers.index(x) for x in key_list] + ModeAdapter._match_data(value, compare_data, key_list, id_list) + # 相对误差大于0.5疑似有精度问题,小值域1e-3不比较相对误差 + for index, item in enumerate(key_list[4:]): + value_diff = value.get(key_list[index]) + if isinstance(value_diff, float) and value_diff != 0 and abs(value_diff) < GraphConst.SMALL_VALUE: + value[item] = ToolTip.SMALL_VALUE_TIP.format(key_list[index]) + continue + relative_err = str2float(value.get(item)) + max_relative_err = max(max_relative_err, relative_err) + node_data[key] = value + max_relative_err = 1 if max_relative_err > 1 else max_relative_err + return max_relative_err + + @staticmethod + def _match_data(data_dict, compare_data, key_list, id_list): + """ + 绑定精度指标到node的input_data和output_data + """ + if len(key_list) != len(id_list): + return + for id, key in zip(id_list, key_list): + data = compare_data[id] + if data is not None and 'nan' not in str(data) and str(data) != ' ': + data_dict[key] = data + else: + data_dict[key] = 'null' + + def parse_result(self, node, compare_data_dict): + """ + 根据结果返回数据,分别是precision_index,和附加数据 + """ + other_dict = {} + if self.is_md5_compare(): + precision_index_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict[0]) + precision_index_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict[1]) + # 所有输入输出md5对比通过,这个节点才算通过 + precision_index = max(precision_index_in, precision_index_out) + other_result = CompareConst.PASS if precision_index == 1 else CompareConst.DIFF + other_dict[CompareConst.RESULT] = other_result + elif self.is_summary_compare(): + precision_index_in = ModeAdapter._add_summary_compare_data(node.input_data, compare_data_dict[0]) + precision_index_out = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict[1]) + precision_index = max(precision_index_in, precision_index_out) + else: + min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict[0]) + min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict[0]) + if min_thousandth_in is not None and min_thousandth_out is not None: + change_percentage = abs(min_thousandth_in - min_thousandth_out) + else: + change_percentage = 0 + precision_index = GraphConst.MAX_INDEX_KEY \ + if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage + return precision_index, other_dict + + def prepare_real_data(self, node): + """ + 为真实数据比较模式准备节点信息 + """ + if self.is_real_data_compare(): + self.compare_nodes.append(node) + return True + return False + + def is_summary_compare(self): + return self.compare_mode == GraphConst.SUMMARY_COMPARE + + def is_md5_compare(self): + return self.compare_mode == GraphConst.MD5_COMPARE + + def is_real_data_compare(self): + return self.compare_mode == GraphConst.REAL_DATA_COMPARE + + def add_csv_data(self, compare_result_list): + if not self.is_real_data_compare(): + return + self.csv_data.extend(compare_result_list) + + def add_error_key(self, node_data): + """ + 根据不同的模式进行提供不同错误信息 + """ + for key, value in node_data.items(): + if not isinstance(value, dict): + continue + if self.is_summary_compare(): + message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, + CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR] + elif self.is_real_data_compare(): + message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] + else: + # 输出件优化 + message = [] + value[GraphConst.ERROR_KEY] = message + node_data[key] = value + + def get_tool_tip(self): + """ + 用于前端展示字段的具体含义 + """ + if self.is_summary_compare(): + tips = { + CompareConst.MAX_DIFF: ToolTip.MAX_DIFF, + CompareConst.MIN_DIFF: ToolTip.MIN_DIFF, + CompareConst.MEAN_DIFF: ToolTip.MEAN_DIFF, + CompareConst.NORM_DIFF: ToolTip.NORM_DIFF} + elif self.is_md5_compare(): + tips = {Const.MD5: ToolTip.MD5} + else: + tips = { + CompareConst.ONE_THOUSANDTH_ERR_RATIO: ToolTip.ONE_THOUSANDTH_ERR_RATIO, + CompareConst.FIVE_THOUSANDTHS_ERR_RATIO: ToolTip.FIVE_THOUSANDTHS_ERR_RATIO, + CompareConst.COSINE: ToolTip.COSINE, + CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR, + CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR} + return json.dumps(tips) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/__init__.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py new file mode 100644 index 0000000000000000000000000000000000000000..87f020a956b24f4b5884e3e60b72513f6f39625f --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py @@ -0,0 +1,119 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.pytorch.visualization.graph.node_op import NodeOp +from msprobe.pytorch.visualization.utils import Suggestions, GraphConst +from msprobe.pytorch.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_mapping_data + + +class BaseNode: + def __init__(self, node_op, node_id, up_node=None): + self.op = node_op + self.id = node_id + self.data = {} + self.output_data = {} + self.input_data = {} + self.upnode = None + self.add_upnode(up_node) + self.subnodes = [] + self.matched_node_link = [] + self.suggestions = {} + self.stack_info = [] + self.micro_step_id = None + + def __str__(self): + info = f'id:\t{self.id}' + return info + + def __eq__(self, other): + """ + 用来判断两个节点是否可以被匹配上,认为结构上是否一致 + """ + if not compare_data(self.input_data, other.input_data): + return False + if not compare_data(self.output_data, other.output_data): + return False + return True + + def compare_mapping_node(self, other): + if not compare_mapping_data(self.input_data, other.input_data): + return False + if not compare_mapping_data(self.output_data, other.output_data): + return False + return True + + def get_suggestions(self): + """ + 精度疑似有问题时,提供一些建议 + """ + if self.op == NodeOp.module: + self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module + self.suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL + elif self.op == NodeOp.function_api: + self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API + self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL + + def set_input_output(self, input_data, output_data): + self.input_data = input_data + self.output_data = output_data + + def add_upnode(self, node): + """ + 绑定upnode,用于对两个节点进行上下级关联 + """ + if not node or node.id == self.id or self.upnode: + return + self.upnode = node + node.subnodes.append(self) + + def add_link(self, node, ancestors): + """ + 在节点匹配成功后进行匹配数据的录入 + Args: + node: 和self相互匹配的节点 + ancestors: 对面节点的祖先信息 + """ + self.matched_node_link = ancestors + node.matched_node_link = ancestors + + def to_dict(self): + """ + 输出数据 + """ + result = {} + result['id'] = self.id + result['node_type'] = self.op.value + result['data'] = self.data + result['output_data'] = format_node_data(self.output_data) + result['input_data'] = format_node_data(self.input_data) + result['upnode'] = self.upnode.id if self.upnode else 'None' + result['subnodes'] = [node.id for node in self.subnodes] + result['matched_node_link'] = self.matched_node_link + result['suggestions'] = self.suggestions + result['stack_info'] = self.stack_info + if self.micro_step_id is not None: + result['micro_step_id'] = self.micro_step_id + return result + + def get_ancestors(self): + """ + 获取节点所有祖先的列表 + """ + ancestors = [] + current_node = self.upnode + while current_node: + ancestors.append(current_node.id) + current_node = current_node.upnode + return list(reversed(ancestors)) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca5c6811b496808c521f8bfd5df58adb4605c1e --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py @@ -0,0 +1,167 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.pytorch.visualization.graph.base_node import BaseNode +from msprobe.pytorch.visualization.graph.node_op import NodeOp +from msprobe.pytorch.visualization.utils import GraphConst +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const + + +class Graph: + def __init__(self, model_name): + self.node_map = {} + self.node_id_map = {} + self.add_node(NodeOp.module, model_name) + self.root = self.get_node(model_name) + + def __str__(self): + infos = [f'{str(self.node_map.get(node_id))}' for node_id in self.node_map] + info = "\n".join(infos) + return info + + @staticmethod + def match(graph_n, node_n, graph_b): + """ + 给定节点n,在另一个graph中匹配它对应的节点。前置条件是它的父节点匹配已经完成 + 目前采用完全匹配的方式,后续可能在这里加入一定的模糊匹配逻辑 + 返回匹配结果,匹配到的节点,以及祖先树。没匹配到则返回None, [] + """ + if not node_n or node_n.id not in graph_b.node_map: + return None, [] + node_b = graph_b.node_map.get(node_n.id) + if node_n != node_b: + return None, [] + ancestors_n = node_n.get_ancestors() + ancestors_b = node_b.get_ancestors() + if ancestors_n != ancestors_b: + return None, [] + return node_b, ancestors_n + + @staticmethod + def mapping_match(node_n, graph_b, mapping_config): + """ + 根据映射配置对节点进行匹配 + """ + node_b = graph_b.node_map.get(mapping_config.get_mapping_string(node_n.id)) + if not node_b or not node_n.compare_mapping_node(node_b): + return None, [], [] + ancestors_n = node_n.get_ancestors() + ancestors_b = node_b.get_ancestors() + return node_b, ancestors_n, ancestors_b + + @staticmethod + def dfs(node, result): + info = node.to_dict() + result[node.id] = info + for subnode in node.subnodes: + Graph.dfs(subnode, result) + + @staticmethod + def split_nodes_by_micro_step(nodes): + """ + 根据Module名称后缀数字, 区分一个step中的多个micro steps, 后缀数字相同代表节点属于同一个micro step. + 如果是非Module节点,分类到前一个Module节点所在的micro step. + """ + result = {} + default_id = 0 + result[default_id] = [] + + for node in nodes: + if node.op == NodeOp.module: + micro_step_id = node.id.split(Const.SEP)[-1] + try: + micro_step_id = int(micro_step_id) + except ValueError: + logger.warning(f'The node id suffix {micro_step_id} is not a number, micro steps cannot be split.') + micro_step_id = 0 + if micro_step_id not in result: + default_id = micro_step_id + result[micro_step_id] = [] + result[micro_step_id].append(node) + else: + result[default_id].append(node) + return result + + def add_node(self, node_op, node_id, up_node=None, id_accumulation=False): + """ + 在graph中进行节点的添加 + Args: + node_op: 需要添加的节点类型 + node_id: 需要添加的节点id + up_node:对应节点的父节点 + id_accumulation: 是否对传入的重复node_id进行累加 + """ + if node_id in self.node_map: + if id_accumulation: + self.node_id_map[node_id] = 0 + else: + return node_id + if id_accumulation: + if node_id in self.node_id_map: + self.node_id_map[node_id] += 1 + else: + self.node_id_map[node_id] = 0 + node_id = f'{node_id}.{self.node_id_map[node_id]}' + node = BaseNode(node_op, node_id, up_node) + self.node_map[node_id] = node + return node_id + + def get_node(self, node_id): + """ + 返回节点,不存在返回None + """ + return self.node_map.get(node_id, None) + + def to_dict(self): + """ + 用于数据输出 + """ + result = {} + result[GraphConst.JSON_ROOT_KEY] = self.root.id if self.root else 'None' + result[GraphConst.JSON_NODE_KEY] = {} + for node_id in self.node_map: + info = self.node_map.get(node_id).to_dict() + result[GraphConst.JSON_NODE_KEY][node_id] = info + return result + + def paging_by_micro_step(self, graph_other=None): + """ + 给graph首层节点增加micro step标记,供前端分页展示,有助于在处理大规模图数据时进行优化和管理 + 比对场景中,同步更新另一个图graph_other中相应节点的micro step信息 + Args: + self: 当前graph + graph_other: 可选参数,另一个graph + Returns: 分批的数量 + """ + batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes) + for batch_number, nodes in batches_n.items(): + for node in nodes: + node.micro_step_id = batch_number + # 在graph_other中更新已匹配节点的micro_step_id + if graph_other and node.matched_node_link: + node_other = graph_other.get_node(node.matched_node_link[-1]) + if node_other: + node_other.micro_step_id = batch_number + # 遍历graph_other根节点下的所有子节点,确保未匹配节点也有micro_step_id + if graph_other: + for node in graph_other.root.subnodes: + if node.micro_step_id is None: + try: + micro_step_id = int(node.id.split(Const.SEP)[-1]) + except ValueError: + micro_step_id = 0 + node.micro_step_id = micro_step_id + return len(batches_n) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_colors.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_colors.py new file mode 100644 index 0000000000000000000000000000000000000000..8e07625140d284afac60ef1f1cf3b40e3a6f4ea7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_colors.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from msprobe.pytorch.visualization.utils import GraphConst, ToolTip + +SUMMARY_DESCRIPTION = "此节点所有输入输出的统计量相对误差, 值越大代表测量值与标杆值的偏差越大, 相对误差计算方式:|(测量值-标杆值)/标杆值|" +REAL_DATA_DESCRIPTION = (f"此节点所有输入的最小双千分之一和所有输出的最小双千分之一的差值的绝对值, 代表双千指标的变化情况, " + f"值越大代表测量值与标杆值的偏差越大, 双千分之一指标计算方式:{ToolTip.ONE_THOUSANDTH_ERR_RATIO}") +MD5_DESCRIPTION_N = "与标杆相比, 此节点任意输入输出的md5值不同" +MD5_DESCRIPTION_Y = "与标杆相比, 此节点所有输入输出的md5值相同" +NOT_MATCHED = "比对过程中节点未匹配上" + + +class NodeColors(Enum): + # 枚举值后缀数字越小, 颜色越浅 + # value值左闭右开, 两个值相同代表固定值 + YELLOW_1 = ("#FFFCF3", { + GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0, 0.2], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION}, + GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0, 0.05], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}, + GraphConst.MD5_COMPARE: {GraphConst.VALUE: [1, 1], GraphConst.DESCRIPTION: MD5_DESCRIPTION_Y}, + }) + YELLOW_2 = ("#FFEDBE", { + GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.2, 0.4], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION}, + GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.05, 0.1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION} + }) + ORANGE_1 = ("#FFDC7F", { + GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.4, 0.6], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION}, + GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.1, 0.15], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION} + }) + ORANGE_2 = ("#FFC62E", { + GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.6, 0.8], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION}, + GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.15, 0.2], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION} + }) + RED = ("#E32020", { + GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.8, 1], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION}, + GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.2, 1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}, + GraphConst.MD5_COMPARE: {GraphConst.VALUE: [0, 0], GraphConst.DESCRIPTION: MD5_DESCRIPTION_N}, + }) + GREY = ("#C7C7C7", { + GraphConst.VALUE: [], GraphConst.DESCRIPTION: NOT_MATCHED + }) + + def __init__(self, hex_value, mode_info): + self.hex_value = hex_value + self.mode_info = mode_info + + @staticmethod + def get_node_colors(mode): + """ + 获取不同比对模式下的颜色说明 + Args: + mode: 比对模式 + Returns: 颜色说明 + """ + return { + color.hex_value: color.get_info_by_mode(mode) for color in NodeColors if color.get_info_by_mode(mode) + } + + @staticmethod + def get_node_error_status(mode, value): + """ + 判断精度数据比对指标是否大于基准值 + Args: + mode: 比对模式 + value: 精度数据比对指标 + Returns: bool + """ + info = NodeColors.ORANGE_1.get_info_by_mode(mode) + if info and GraphConst.VALUE in info: + value_range = info[GraphConst.VALUE] + return value > value_range[0] + return False + + def get_info_by_mode(self, mode): + if isinstance(self.mode_info, dict): + # 检查是否是模式特定的信息 + if isinstance(next(iter(self.mode_info.values())), dict): + return self.mode_info.get(mode, {}) + else: + # 所有模式共享相同的信息 + return self.mode_info + return {} diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_op.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9be4923c5d6913d3e4885fca34227f2db75e28d6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_op.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +import re +from msprobe.pytorch.visualization.builder.msprobe_adapter import op_patterns + + +class NodeOp(Enum): + module = 0 + function_api = 1 + api_collection = 9 + + @staticmethod + def get_node_op(node_name: str): + """ + 基于代表节点的字符串,解析节点种类 + """ + for op in NodeOp: + index = op.value + if index < 0 or index >= len(op_patterns): + raise Exception("NodeOp and op_patterns in MsprobeAdapter do not match") + pattern = op_patterns[index] + if re.match(pattern, node_name): + return op + raise Exception(f"Cannot parse node_name {node_name} into NodeOp") diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py new file mode 100644 index 0000000000000000000000000000000000000000..0ab966a3f33214095cbb6e01872a0878e63b59b7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +from msprobe.pytorch.visualization.compare.graph_comparator import GraphComparator +from msprobe.pytorch.visualization.utils import GraphConst +from msprobe.pytorch.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig +from msprobe.core.common.log import logger +from msprobe.pytorch.visualization.mapping_config import MappingConfig +from msprobe.pytorch.visualization.graph.node_colors import NodeColors + +current_time = time.strftime("%Y%m%d%H%M%S") + + +def compare_graph(dump_path_n, dump_path_b, out_path, model_name='Model', mapping_file=None): + logger.info('Start building model graphs...') + # 对两个数据进行构图 + construct_path_n = os.path.join(dump_path_n, GraphConst.CONSTRUCT_FILE) + construct_path_b = os.path.join(dump_path_b, GraphConst.CONSTRUCT_FILE) + data_path_n = os.path.join(dump_path_n, GraphConst.DUMP_FILE) + data_path_b = os.path.join(dump_path_b, GraphConst.DUMP_FILE) + graph_n = GraphBuilder.build(construct_path_n, data_path_n, model_name) + graph_b = GraphBuilder.build(construct_path_b, data_path_b, model_name) + logger.info('Model graphs built successfully, start Comparing graphs...') + # 基于graph、stack和data进行比较 + stack_path = os.path.join(dump_path_n, GraphConst.STACK_FILE) + graph_comparator = GraphComparator([graph_n, graph_b], [data_path_n, data_path_b], stack_path, out_path, + mapping_config=MappingConfig(mapping_file) if mapping_file else None) + graph_comparator.compare() + micro_steps = graph_n.paging_by_micro_step(graph_b) + output_path = os.path.join(out_path, f'compare_{current_time}.vis') + export_config = GraphExportConfig(graph_n, graph_b, graph_comparator.ma.get_tool_tip(), + NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps) + GraphBuilder.to_json(output_path, export_config) + logger.info(f'Model graphs compared successfully, the result file is saved in {output_path}') + + +def build_graph(dump_path, out_path, model_name='Model'): + logger.info('Start building model graph...') + construct_path = os.path.join(dump_path, GraphConst.CONSTRUCT_FILE) + data_path = os.path.join(dump_path, GraphConst.DUMP_FILE) + output_path = os.path.join(out_path, f'build_{current_time}.vis') + graph = GraphBuilder.build(construct_path, data_path, model_name) + micro_steps = graph.paging_by_micro_step() + GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps)) + logger.info(f'Model graph built successfully, the result file is saved in {output_path}') diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py b/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d986493513a345483bd6c32f808cca002b3cba5b --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py @@ -0,0 +1,77 @@ +import re +import yaml +from msprobe.core.common.file_utils import FileOpen +from msprobe.core.common.const import Const +from msprobe.pytorch.visualization.utils import GraphConst + + +class MappingConfig: + MAX_STRING_LEN = 10000 + + def __init__(self, yaml_file): + with FileOpen(yaml_file, 'r') as file: + config = yaml.safe_load(file) + try: + self.config = {key: self.validate(key, value) for data in config for key, value in data.items()} + except Exception as e: + raise RuntimeError("Line of yaml contains content that is not '- key: value'.") from e + self.classify_config = self._classify_and_sort_keys() + + @staticmethod + def validate(key, value): + if not isinstance(key, str): + raise ValueError(f"{key} must be a string.") + if not isinstance(value, str): + raise ValueError(f"{value} must be a string.") + return value + + @staticmethod + def convert_to_regex(s): + """ + 字符串转换为正则表达式, {}替换为d+以匹配一个或多个数字, 开始和结束添加.*以匹配任意前缀和后缀 + Args: + s: 字符串 + Returns: 正则表达式 + """ + escaped_pattern = re.escape(s) + pattern = re.sub(r'\\\{\\\}', r'\\d+', escaped_pattern) + pattern = f'.*{pattern}.*' + return pattern + + @staticmethod + def _replace_parts(origin_string, mapping_key, mapping_value): + if GraphConst.BRACE in mapping_key: + parts = mapping_key.split(GraphConst.BRACE) + m_parts = mapping_value.split(GraphConst.BRACE) + return origin_string.replace(parts[0], m_parts[0]).replace(parts[1], m_parts[1]) + else: + return origin_string.replace(mapping_key, mapping_value) + + def get_mapping_string(self, origin_string: str): + if len(origin_string) > MappingConfig.MAX_STRING_LEN: + return origin_string + for category, items in self.classify_config.items(): + if category in origin_string: + for key, value in items: + if re.match(MappingConfig.convert_to_regex(key), origin_string): + return MappingConfig._replace_parts(origin_string, key, value) + return origin_string + + def _classify_and_sort_keys(self): + categorized_dict = {} + for key, value in self.config.items(): + parts = key.split(Const.SEP) + # 获取第一个部分作为新的分类key + category_key = parts[0] + + if category_key not in categorized_dict: + categorized_dict[category_key] = [] + + # 将原始的key-value对添加到对应的分类中 + categorized_dict[category_key].append((key, value)) + + # 对每个分类中的项按key中的.数量进行排序, .数量越多排越靠前, 优先匹配 + for category in categorized_dict: + categorized_dict[category].sort(key=lambda x: -x[0].count(Const.SEP)) + + return categorized_dict diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a81f4f20922bb9fa58ce1d1e358ed1e0ce12c74 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from msprobe.core.common.file_utils import FileOpen +from msprobe.core.common.const import CompareConst +from msprobe.pytorch.compare.acc_compare import result_to_csv + + +def load_json_file(file_path): + """ + 加载json文件 + """ + try: + with FileOpen(file_path, 'r') as f: + file_dict = json.load(f) + if not isinstance(file_dict, dict): + return {} + return file_dict + except json.JSONDecodeError: + return {} + + +def load_data_json_file(file_path): + """ + 加载dump.json中的data字段 + """ + return load_json_file(file_path).get(GraphConst.DATA_KEY, {}) + + +def save_json_file(file_path, data): + """ + 保存json文件 + """ + with FileOpen(file_path, 'w') as f: + f.write(json.dumps(data, indent=4)) + + +def get_csv_df(md5_compare, summary_compare, stack, csv_data): + """ + 调用acc接口写入csv + """ + return result_to_csv(md5_compare, summary_compare, stack, csv_data) + + +def str2float(percentage_str): + """ + 百分比字符串转换转换为浮点型 + Args: + percentage_str: '0.00%', '23.4%' + Returns: float 0.00, 0.234 + """ + try: + percentage_str = percentage_str.strip('%') + return float(percentage_str) / 100 + except ValueError: + return 0 + + +class ToolTip: + MAX_DIFF = 'NPU与标杆API统计信息比对,最大值的差值' + MIN_DIFF = 'NPU与标杆API统计信息比对,最小值的差值' + MEAN_DIFF = 'NPU与标杆API统计信息比对,平均值的差值' + NORM_DIFF = 'NPU与标杆API统计信息比对,2范数(平方根)的差值' + MD5 = '数据MD5信息,用于比较两个数据信息是否完全一致' + ONE_THOUSANDTH_ERR_RATIO = 'Tensor中的元素逐个与对应的标杆数据对比,相对误差小于千分之一的比例占总元素个数的比例,比例越接近1越好' + FIVE_THOUSANDTHS_ERR_RATIO = 'Tensor中的元素逐个与对应的标杆数据对比,相对误差小于千分之五的比例占总元素个数的比例,比例越接近1越好' + COSINE = '通过计算两个向量的余弦值来判断其相似度,数值越接近于1说明计算出的两个张量越相似,实际可接受阈值为大于0.99。在计算中可能会存在nan,主要由于可能会出现其中一个向量为0' + MAX_ABS_ERR = '当最大绝对误差越接近0表示其计算的误差越小,实际可接受阈值为小于0.001' + MAX_RELATIVE_ERR = '当最大相对误差越接近0表示其计算的误差越小。当dump数据中存在0或Nan时,比对结果中最大相对误差则出现inf或Nan的情况,属于正常现象' + SMALL_VALUE_TIP = '{} 小于1e-3,不计算相对误差' + + +class Suggestions: + Module = '此模块精度比对结果疑似异常,请使用msprobe工具的数据采集功能对模块中的api进行dump比对' + API = '此api精度比对结果疑似异常,请使用msprobe工具的预检功能对api进行精度检测' + DUMP = 'msprobe工具的数据采集功能' + DUMP_URL = 'https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/pytorch/doc/dump.md' + API_ACCURACY_CHECKER = 'msprobe工具的预检功能' + API_ACCURACY_CHECKER_URL = 'https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md' + + +class GraphConst: + CONSTRUCT_FILE = 'construct.json' + DUMP_FILE = 'dump.json' + STACK_FILE = 'stack.json' + GRAPH_FILE = 'graph.vis' + ERROR_KEY = 'error_key' + SUMMARY_COMPARE = 0 + MD5_COMPARE = 1 + REAL_DATA_COMPARE = 2 + JSON_NPU_KEY = 'NPU' + JSON_BENCH_KEY = 'Bench' + JSON_TIP_KEY = 'ToolTip' + JSON_ROOT_KEY = 'root' + JSON_NODE_KEY = 'node' + DATA_KEY = 'data' + REAL_DATA_TH = 0.1 + MAX_RELATIVE_ERR_TH = 0.5 + ROUND_TH = 6 + JSON_INDEX_KEY = 'precision_index' + MAX_INDEX_KEY = 1 + MIN_INDEX_KEY = 0 + SUGGEST_KEY = 'text' + TAG_NA = 'na' + OUTPUT_INDEX_TWO = -2 + OUTPUT_INDEX_THREE = -3 + OUTPUT_MIN_LEN = 3 + INPUT = 'input' + OUTPUT = 'output' + STR_MAX_LEN = 50 + SMALL_VALUE = 1e-3 + MD5_INDEX_LIST = [CompareConst.RESULT] + REAL_DATA_INDEX_LIST = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR, + CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] + SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF, + CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, + CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR] + APIS_BETWEEN_MODULES = 'Apis_Between_Modules' + NULL = 'null' + NONE = 'None' + VALUE = 'value' + BRACE = '{}' + DESCRIPTION = 'description' + COLORS = 'Colors' + MICRO_STEPS = 'MicroSteps' diff --git a/debug/accuracy_tools/atat/test/core_ut/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py similarity index 90% rename from debug/accuracy_tools/atat/test/core_ut/test_utils.py rename to debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py index b3273358e43593e14f931a5675e9039c684d1c59..27ffe411586707e384a404edfa39b6ac5a013464 100644 --- a/debug/accuracy_tools/atat/test/core_ut/test_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py @@ -20,9 +20,9 @@ import uuid from unittest import TestCase from unittest.mock import patch, MagicMock, mock_open -from atat.core.common.log import logger -from atat.core.common.const import Const -from atat.core.common.utils import (CompareException, +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const +from msprobe.core.common.utils import (CompareException, check_seed_all, check_inplace_op, make_dump_path_if_not_exists, @@ -41,7 +41,7 @@ from atat.core.common.utils import (CompareException, check_regex_prefix_format_valid, get_dump_data_path, task_dumppath_get) -from atat.core.common.file_check import FileCheckConst +from msprobe.core.common.file_utils import FileCheckConst class TestUtils(TestCase): @@ -88,7 +88,7 @@ class TestUtils(TestCase): raise OSError if not os.path.exists(dirname): - with patch("atat.core.common.utils.Path.mkdir", new=test_mkdir): + with patch("msprobe.core.common.utils.Path.mkdir", new=test_mkdir): with self.assertRaises(CompareException) as context: make_dump_path_if_not_exists(dirname) self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) @@ -171,7 +171,7 @@ class TestUtils(TestCase): file_path = os.path.realpath(__file__) dirname = os.path.dirname(file_path) - with patch("atat.core.common.utils.FileChecker", new=TestFileChecker): + with patch("msprobe.core.common.utils.FileChecker", new=TestFileChecker): check_file_or_directory_path(file_path, isdir=False) self.assertTrue(TestFileChecker.checked) self.assertEqual(TestFileChecker.file_path, file_path) @@ -179,7 +179,7 @@ class TestUtils(TestCase): self.assertEqual(TestFileChecker.ability, FileCheckConst.READ_ABLE) TestFileChecker.checked = False - with patch("atat.core.common.utils.FileChecker", new=TestFileChecker): + with patch("msprobe.core.common.utils.FileChecker", new=TestFileChecker): check_file_or_directory_path(dirname, isdir=True) self.assertTrue(TestFileChecker.checked) self.assertEqual(TestFileChecker.file_path, dirname) @@ -216,9 +216,9 @@ class TestUtils(TestCase): mock_check_file_or_directory_path = MagicMock() mock_check_json_file = MagicMock() - with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ - patch("atat.core.common.utils.check_json_file", new=mock_check_json_file), \ - patch("atat.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path): + with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("msprobe.core.common.utils.check_json_file", new=mock_check_json_file), \ + patch("msprobe.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path): check_compare_param(params, "output_path") check_compare_param(params, "output_path", summary_compare=False, md5_compare=True) for i in range(len(call_args)): @@ -261,7 +261,7 @@ class TestUtils(TestCase): _check_json(handler, "test.json") self.assertEqual(handler.string, "0_0") - @patch("atat.core.common.utils._check_json") + @patch("msprobe.core.common.utils._check_json") def test_check_json_file(self, _mock_check_json): input_param = { "npu_json_path": "npu_json_path", @@ -275,7 +275,7 @@ class TestUtils(TestCase): @patch.object(logger, "error") def test_check_file_size(self, mock_error): - with patch("atat.core.common.utils.os.path.getsize", return_value=120): + with patch("msprobe.core.common.utils.os.path.getsize", return_value=120): with self.assertRaises(CompareException) as context: check_file_size("input_file", 100) self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR) @@ -294,7 +294,7 @@ class TestUtils(TestCase): self.assertEqual(str(context.exception), f"prefix contains invalid characters, " f"prefix pattern {Const.REGEX_PREFIX_PATTERN}") - @patch("atat.core.common.utils.check_file_or_directory_path") + @patch("msprobe.core.common.utils.check_file_or_directory_path") def test_get_dump_data_path(self, mock_check_file_or_directory_path): file_path = os.path.realpath(__file__) dirname = os.path.dirname(file_path) @@ -322,23 +322,23 @@ class TestUtils(TestCase): mock_error.assert_called_with("Please check the json path is valid.") input_param["npu_json_path"] = "npu_json_path" - with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ - patch("atat.core.common.utils.json.load", return_value=npu_json): + with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("msprobe.core.common.utils.json.load", return_value=npu_json): summary_compare, md5_compare = task_dumppath_get(input_param) self.assertFalse(summary_compare) self.assertFalse(md5_compare) npu_json["task"] = Const.STATISTICS - with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ - patch("atat.core.common.utils.json.load", return_value=npu_json), \ - patch("atat.core.common.utils.md5_find", return_value=True): + with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("msprobe.core.common.utils.json.load", return_value=npu_json), \ + patch("msprobe.core.common.utils.md5_find", return_value=True): summary_compare, md5_compare = task_dumppath_get(input_param) self.assertFalse(summary_compare) self.assertTrue(md5_compare) npu_json["task"] = Const.OVERFLOW_CHECK - with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ - patch("atat.core.common.utils.json.load", return_value=npu_json): + with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("msprobe.core.common.utils.json.load", return_value=npu_json): with self.assertRaises(CompareException) as context: task_dumppath_get(input_param) self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..eedbe5be7e0360d7874439357419510cbde73b71 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py @@ -0,0 +1,47 @@ +import unittest +from unittest.mock import patch, mock_open, MagicMock + +from msprobe.core.common.utils import Const +from msprobe.core.data_dump.data_collector import DataCollector +from msprobe.pytorch.debugger.debugger_config import DebuggerConfig +from msprobe.pytorch.pt_config import parse_json_config + + +class TestDataCollector(unittest.TestCase): + def setUp(self): + mock_json_data = { + "dump_path": "./ut_dump", + } + with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ + patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data): + common_config, task_config = parse_json_config("./config.json", Const.STATISTICS) + config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1") + self.data_collector = DataCollector(config) + + def test_update_data(self): + self.data_collector.config.task = Const.OVERFLOW_CHECK + self.data_collector.data_processor.has_overflow = True + with patch("msprobe.core.data_dump.json_writer.DataWriter.update_data", return_value=None): + result1 = self.data_collector.update_data("test message", "test1:") + self.assertEqual(result1, "test1:Overflow detected.") + + self.data_collector.data_processor.has_overflow = False + result2 = self.data_collector.update_data("test message", "test2:") + self.assertEqual(result2, "test2:No Overflow, OK.") + + self.data_collector.config.task = Const.STATISTICS + self.data_collector.data_processor.has_overflow = True + with patch("msprobe.core.data_dump.json_writer.DataWriter.update_data", return_value=None): + result3 = self.data_collector.update_data("test message", "test3") + self.assertEqual(result3, "test3") + + def test_pre_forward_data_collect(self): + self.data_collector.check_scope_and_pid = MagicMock(return_value=False) + self.data_collector.is_inplace = MagicMock(return_value=False) + self.data_collector.data_processor.analyze_pre_forward = MagicMock() + name = "TestModule.forward" + pid = 123 + + self.data_collector.pre_forward_data_collect(name, None, pid, None) + self.data_collector.check_scope_and_pid.assert_called_once_with( + self.data_collector.scope, "TestModule.backward", 123) diff --git a/debug/accuracy_tools/atat/test/core_ut/data_dump/test_json_writer.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py similarity index 97% rename from debug/accuracy_tools/atat/test/core_ut/data_dump/test_json_writer.py rename to debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py index 867da001e610c7e006a9b7d4a027ecf4d6f43a69..4161377a6067c53fa4735a13641c5609bd0d06c9 100644 --- a/debug/accuracy_tools/atat/test/core_ut/data_dump/test_json_writer.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py @@ -1,10 +1,10 @@ import unittest -from atat.core.data_dump.json_writer import DataWriter +from msprobe.core.data_dump.json_writer import DataWriter import os import csv -from atat.core.common.file_check import FileOpen -from atat.core.common import utils +from msprobe.core.common.file_utils import FileOpen +from msprobe.core.common import utils from pathlib import Path import json diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_scope.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_scope.py new file mode 100644 index 0000000000000000000000000000000000000000..1989fd0a95a5894b012aa916cdf44c625de27b1a --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_scope.py @@ -0,0 +1,151 @@ +import unittest +from unittest.mock import MagicMock + +from msprobe.core.common.exceptions import ScopeException +from msprobe.core.data_dump.scope import ( + build_scope, + build_range_scope_according_to_scope_name, + BaseScope, + ListScope, + RangeScope, + APIRangeScope, + ModuleRangeScope +) + + +class TestBuildScope(unittest.TestCase): + def test_build_scope(self): + scope_class = MagicMock() + result1 = build_scope(scope_class, None, None) + self.assertEqual(result1, None) + + api_list = ['api1', 'api2'] + result2 = build_scope(scope_class, None, api_list) + self.assertEqual(result2, scope_class.return_value) + + def test_build_range_scope_according_to_scope_name(self): + result = build_range_scope_according_to_scope_name([], []) + self.assertIsInstance(result, APIRangeScope) + + +class TestBaseScope(unittest.TestCase): + def test_rectify_args(self): + scope = [] + api_list = "invalid_api_list" + with self.assertRaises(ScopeException) as context: + BaseScope.rectify_args(scope, api_list) + self.assertEqual(context.exception.code, ScopeException.InvalidApiStr) + + api_list = [1, 2, 3] + with self.assertRaises(ScopeException) as context: + BaseScope.rectify_args(scope, api_list) + self.assertEqual(context.exception.code, ScopeException.InvalidApiStr) + + scope = "module1" + api_list = [] + + expected_scope = ["module1"] + expected_api_list = [] + result_scope, result_api_list = BaseScope.rectify_args(scope, api_list) + self.assertEqual(result_scope, expected_scope) + self.assertEqual(result_api_list, expected_api_list) + + scope = 123 + api_list = [] + with self.assertRaises(ScopeException) as context: + BaseScope.rectify_args(scope, api_list) + self.assertEqual(context.exception.code, ScopeException.InvalidScope) + + scope = ["module1", 2, "module3"] + api_list = [] + with self.assertRaises(ScopeException) as context: + BaseScope.rectify_args(scope, api_list) + self.assertEqual(context.exception.code, ScopeException.InvalidScope) + + +class TestListScope(unittest.TestCase): + def test_rectify_args(self): + scope = ["module1"] + api_list = ["api1"] + with self.assertRaises(ScopeException) as context: + ListScope.rectify_args(scope, api_list) + self.assertEqual(context.exception.code, ScopeException.ArgConflict) + + def test_check(self): + list_scope = ListScope([], []) + module_name = "module1" + result = list_scope.check(module_name) + self.assertTrue(result) + + list_scope = ListScope(["module1"], []) + module_name = "module1" + result = list_scope.check(module_name) + self.assertTrue(result) + + list_scope = ListScope(["module1"], []) + module_name = "module2" + result = list_scope.check(module_name) + self.assertFalse(result) + + +class TestRangeScope(unittest.TestCase): + def test_rectify_args(self): + scope = ["module1", "module2", "module3"] + with self.assertRaises(ScopeException) as context: + RangeScope.rectify_args(scope, []) + self.assertEqual(context.exception.code, ScopeException.InvalidScope) + + scope = ["module1"] + expected_scope = ["module1", "module1"] + result_scope, result_api_list = RangeScope.rectify_args(scope, []) + self.assertEqual(result_scope, expected_scope) + + +class TestAPIRangeScope(unittest.TestCase): + def test_check_scope_is_valid(self): + api_range_scope = APIRangeScope([], []) + result = api_range_scope.check_scope_is_valid() + self.assertTrue(result) + + def test_check(self): + api_range_scope = APIRangeScope([], []) + api_name = "api1" + result = api_range_scope.check(api_name) + self.assertTrue(result) + + +class TestModuleRangeScope(unittest.TestCase): + def test_check_scope_is_valid(self): + module_range_scope = ModuleRangeScope([], []) + result = module_range_scope.check_scope_is_valid() + self.assertTrue(result) + + def test_begin_module(self): + module_range_scope = ModuleRangeScope(["module1", "module2"], []) + module_name = "module1" + module_range_scope.begin_module(module_name) + self.assertTrue(module_range_scope.in_scope) + + module_range_scope = ModuleRangeScope(["module1", "module2"], []) + module_name = "module3" + module_range_scope.begin_module(module_name) + self.assertFalse(module_range_scope.in_scope) + + def test_end_module(self): + module_range_scope = ModuleRangeScope(["module1", "module2"], []) + module_name = "module2" + module_range_scope.in_scope = True + module_range_scope.end_module(module_name) + self.assertFalse(module_range_scope.in_scope) + + module_range_scope = ModuleRangeScope(["module1", "module2"], []) + module_name = "module3" + module_range_scope.in_scope = True + module_range_scope.end_module(module_name) + self.assertTrue(module_range_scope.in_scope) + + def test_check(self): + module_range_scope = ModuleRangeScope([], []) + module_name = "module1" + result = module_range_scope.check(module_name) + self.assertTrue(result) diff --git a/debug/accuracy_tools/atat/test/core_ut/test_common_config.py b/debug/accuracy_tools/msprobe/test/core_ut/test_common_config.py similarity index 83% rename from debug/accuracy_tools/atat/test/core_ut/test_common_config.py rename to debug/accuracy_tools/msprobe/test/core_ut/test_common_config.py index 00b17e1f1cfce0d28d4bd5c3a2130a2d89eb7409..8b2138a485b7ddac32b31ff3784df6299276b4ba 100644 --- a/debug/accuracy_tools/atat/test/core_ut/test_common_config.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/test_common_config.py @@ -17,10 +17,10 @@ from unittest import TestCase from unittest.mock import patch -from atat.core.common.log import logger -from atat.core.common.const import Const -from atat.core.common.exceptions import MsaccException -from atat.core.common_config import CommonConfig, BaseConfig +from msprobe.core.common.log import logger +from msprobe.core.common.const import Const +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common_config import CommonConfig, BaseConfig class TestCommonConfig(TestCase): @@ -44,7 +44,7 @@ class TestCommonConfig(TestCase): self.assertEqual(mock_error_log_with_exp.call_args[0][0], "task is invalid, it should be one of {}".format(Const.TASK_LIST)) self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), - MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)) json_config.update({"task": Const.TENSOR}) json_config.update({"rank": 0}) @@ -52,7 +52,7 @@ class TestCommonConfig(TestCase): self.assertEqual(mock_error_log_with_exp.call_args[0][0], "rank is invalid, it should be a list") self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), - MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)) json_config.update({"task": Const.TENSOR}) json_config.update({"rank": [0]}) @@ -61,7 +61,7 @@ class TestCommonConfig(TestCase): self.assertEqual(mock_error_log_with_exp.call_args[0][0], "step is invalid, it should be a list") self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), - MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)) json_config.update({"task": Const.TENSOR}) json_config.update({"rank": [0]}) @@ -71,7 +71,7 @@ class TestCommonConfig(TestCase): self.assertEqual(mock_error_log_with_exp.call_args[0][0], "level is invalid, it should be one of {}".format(Const.LEVEL_LIST)) self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), - MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)) json_config.update({"task": Const.TENSOR}) json_config.update({"rank": [0]}) @@ -82,7 +82,7 @@ class TestCommonConfig(TestCase): self.assertEqual(mock_error_log_with_exp.call_args[0][0], "seed is invalid, it should be an integer") self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), - MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)) json_config.update({"task": Const.TENSOR}) json_config.update({"rank": [0]}) @@ -94,7 +94,7 @@ class TestCommonConfig(TestCase): self.assertEqual(mock_error_log_with_exp.call_args[0][0], "is_deterministic is invalid, it should be a boolean") self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), - MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)) json_config.update({"task": Const.TENSOR}) json_config.update({"rank": [0]}) @@ -107,7 +107,7 @@ class TestCommonConfig(TestCase): self.assertEqual(mock_error_log_with_exp.call_args[0][0], "enable_dataloader is invalid, it should be a boolean") self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), - MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)) @patch.object(logger, "error_log_with_exp") def test_base_config(self, mock_error_log_with_exp): @@ -121,7 +121,7 @@ class TestCommonConfig(TestCase): self.assertIsNone(base_config.backward_input) self.assertIsNone(base_config.file_format) self.assertIsNone(base_config.summary_mode) - self.assertIsNone(base_config.overflow_num) + self.assertIsNone(base_config.overflow_nums) self.assertIsNone(base_config.check_mode) json_config.update({"scope": "Tensor_Add"}) @@ -130,7 +130,7 @@ class TestCommonConfig(TestCase): self.assertEqual(mock_error_log_with_exp.call_args[0][0], "scope is invalid, it should be a list") self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), - MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)) json_config.update({"scope": ["Tensor_Add"]}) json_config.update({"list": "Tensor_Add"}) @@ -139,7 +139,7 @@ class TestCommonConfig(TestCase): self.assertEqual(mock_error_log_with_exp.call_args[0][0], "list is invalid, it should be a list") self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), - MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)) json_config.update({"scope": ["Tensor_Add"]}) json_config.update({"list": ["Tensor_Add"]}) @@ -149,4 +149,4 @@ class TestCommonConfig(TestCase): self.assertEqual(mock_error_log_with_exp.call_args[0][0], "data_mode is invalid, it should be a list") self.assertEqual(str(mock_error_log_with_exp.call_args[0][1]), - MsaccException.err_strs.get(MsaccException.INVALID_PARAM_ERROR)) + MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)) diff --git a/debug/accuracy_tools/atat/test/core_ut/test_file_check.py b/debug/accuracy_tools/msprobe/test/core_ut/test_file_check.py similarity index 85% rename from debug/accuracy_tools/atat/test/core_ut/test_file_check.py rename to debug/accuracy_tools/msprobe/test/core_ut/test_file_check.py index aa7882aa5906b8e472495ede42653f9f5584f573..c3f7836bf177c283b246e061d2df61f682175c38 100644 --- a/debug/accuracy_tools/atat/test/core_ut/test_file_check.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/test_file_check.py @@ -19,10 +19,10 @@ import os from unittest import TestCase from unittest.mock import patch, MagicMock -from atat.core.common.log import logger -from atat.core.common.const import FileCheckConst -from atat.core.common.exceptions import FileCheckException -from atat.core.common.file_check import (check_link, +from msprobe.core.common.log import logger +from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.exceptions import FileCheckException +from msprobe.core.common.file_utils import (check_link, check_path_length, check_path_exists, check_path_readability, @@ -40,7 +40,7 @@ from atat.core.common.file_check import (check_link, class TestFileCheckUtil(TestCase): @patch.object(logger, "error") def test_check_link(self, mock_logger_error): - with patch("atat.core.common.file_check.os.path.islink", return_value=True): + with patch("msprobe.core.common.file_utils.os.path.islink", return_value=True): with self.assertRaises(FileCheckException) as context: check_link("link_path") self.assertEqual(str(context.exception), @@ -72,7 +72,7 @@ class TestFileCheckUtil(TestCase): @patch.object(logger, "error") def test_check_path_exists(self, mock_logger_error): - with patch("atat.core.common.file_check.os.path.exists", return_value=False): + with patch("msprobe.core.common.file_utils.os.path.exists", return_value=False): with self.assertRaises(FileCheckException) as context: check_path_exists("file_path") self.assertEqual(str(context.exception), @@ -82,7 +82,7 @@ class TestFileCheckUtil(TestCase): @patch.object(logger, "error") def test_check_path_readability(self, mock_logger_error): path = "file_path" - with patch("atat.core.common.file_check.os.access", return_value=False): + with patch("msprobe.core.common.file_utils.os.access", return_value=False): with self.assertRaises(FileCheckException) as context: check_path_readability(path) self.assertEqual(str(context.exception), @@ -91,14 +91,14 @@ class TestFileCheckUtil(TestCase): mock_access = MagicMock() mock_access.return_value = True - with patch("atat.core.common.file_check.os.access", new=mock_access): + with patch("msprobe.core.common.file_utils.os.access", new=mock_access): check_path_readability(path) self.assertEqual(mock_access.call_args[0], (path, os.R_OK)) @patch.object(logger, "error") def test_check_path_writability(self, mock_logger_error): path = "file_path" - with patch("atat.core.common.file_check.os.access", return_value=False): + with patch("msprobe.core.common.file_utils.os.access", return_value=False): with self.assertRaises(FileCheckException) as context: check_path_writability(path) self.assertEqual(str(context.exception), @@ -107,14 +107,14 @@ class TestFileCheckUtil(TestCase): mock_access = MagicMock() mock_access.return_value = True - with patch("atat.core.common.file_check.os.access", new=mock_access): + with patch("msprobe.core.common.file_utils.os.access", new=mock_access): check_path_writability(path) self.assertEqual(mock_access.call_args[0], (path, os.W_OK)) @patch.object(logger, "error") def test_check_path_executable(self, mock_logger_error): path = "file_path" - with patch("atat.core.common.file_check.os.access", return_value=False): + with patch("msprobe.core.common.file_utils.os.access", return_value=False): with self.assertRaises(FileCheckException) as context: check_path_executable(path) self.assertEqual(str(context.exception), @@ -123,7 +123,7 @@ class TestFileCheckUtil(TestCase): mock_access = MagicMock() mock_access.return_value = True - with patch("atat.core.common.file_check.os.access", new=mock_access): + with patch("msprobe.core.common.file_utils.os.access", new=mock_access): check_path_executable(path) self.assertEqual(mock_access.call_args[0], (path, os.X_OK)) @@ -135,7 +135,7 @@ class TestFileCheckUtil(TestCase): path = "file_path" mock_stat = TestStat(0o002) - with patch("atat.core.common.file_check.os.stat", return_value=mock_stat): + with patch("msprobe.core.common.file_utils.os.stat", return_value=mock_stat): with self.assertRaises(FileCheckException) as context: check_other_user_writable(path) self.assertEqual(str(context.exception), @@ -147,7 +147,7 @@ class TestFileCheckUtil(TestCase): def test_check_path_owner_consistent(self, mock_logger_error): file_path = os.path.realpath(__file__) file_owner = os.stat(file_path).st_uid - with patch("atat.core.common.file_check.os.getuid", return_value=file_owner+1): + with patch("msprobe.core.common.file_utils.os.getuid", return_value=file_owner+1): with self.assertRaises(FileCheckException) as context: check_path_owner_consistent(file_path) self.assertEqual(str(context.exception), @@ -160,7 +160,7 @@ class TestFileCheckUtil(TestCase): path = "path" mock_re_match = MagicMock() mock_re_match.return_value = False - with patch("atat.core.common.file_check.re.match", new=mock_re_match): + with patch("msprobe.core.common.file_utils.re.match", new=mock_re_match): with self.assertRaises(FileCheckException) as context: check_path_pattern_vaild(path) self.assertEqual(str(context.exception), @@ -181,8 +181,8 @@ class TestFileCheckUtil(TestCase): def test_check_common_file_size(self): mock_check_file_size = MagicMock() - with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \ - patch("atat.core.common.file_check.check_file_size", new=mock_check_file_size): + with patch("msprobe.core.common.file_utils.os.path.isfile", return_value=True), \ + patch("msprobe.core.common.file_utils.check_file_size", new=mock_check_file_size): for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items(): check_common_file_size(suffix) mock_check_file_size.assert_called_with(suffix, max_size) @@ -201,16 +201,16 @@ class TestFileCheckUtil(TestCase): def test_check_path_type(self, mock_logger_error): file_path = "file_path" - with patch("atat.core.common.file_check.os.path.isfile", return_value=False), \ - patch("atat.core.common.file_check.os.path.isdir", return_value=True): + with patch("msprobe.core.common.file_utils.os.path.isfile", return_value=False), \ + patch("msprobe.core.common.file_utils.os.path.isdir", return_value=True): with self.assertRaises(FileCheckException) as context: check_path_type(file_path, FileCheckConst.FILE) self.assertEqual(str(context.exception), FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) mock_logger_error.assert_called_with(f"The {file_path} should be a file!") - with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \ - patch("atat.core.common.file_check.os.path.isdir", return_value=False): + with patch("msprobe.core.common.file_utils.os.path.isfile", return_value=True), \ + patch("msprobe.core.common.file_utils.os.path.isdir", return_value=False): with self.assertRaises(FileCheckException) as context: check_path_type(file_path, FileCheckConst.DIR) self.assertEqual(str(context.exception), diff --git a/debug/accuracy_tools/atat/test/core_ut/test_log.py b/debug/accuracy_tools/msprobe/test/core_ut/test_log.py similarity index 92% rename from debug/accuracy_tools/atat/test/core_ut/test_log.py rename to debug/accuracy_tools/msprobe/test/core_ut/test_log.py index 6d7998d5ae068bf5044eb097b92cd5dbc6971e0a..1687c48d025a72a9237a561cc3bea0dd473f6fbd 100644 --- a/debug/accuracy_tools/atat/test/core_ut/test_log.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/test_log.py @@ -17,11 +17,11 @@ from unittest import TestCase from unittest.mock import patch, MagicMock -from atat.core.common.log import BaseLogger, logger +from msprobe.core.common.log import BaseLogger, logger class TestLog(TestCase): - @patch("atat.core.common.log.print") + @patch("msprobe.core.common.log.print") def test__print_log(self, mock_print): logger._print_log("level", "msg") self.assertIn("[level] msg", mock_print.call_args[0][0]) @@ -75,7 +75,7 @@ class TestLog(TestCase): @patch.object(BaseLogger, "get_rank") def test_info_on_rank_0(self, mock_get_rank): mock_print = MagicMock() - with patch("atat.core.common.log.print", new=mock_print): + with patch("msprobe.core.common.log.print", new=mock_print): mock_get_rank.return_value = 0 logger.info_on_rank_0("msg") self.assertIn("[INFO] msg", mock_print.call_args[0][0]) @@ -87,7 +87,7 @@ class TestLog(TestCase): @patch.object(BaseLogger, "get_rank") def test_error_on_rank_0(self, mock_get_rank): mock_print = MagicMock() - with patch("atat.core.common.log.print", new=mock_print): + with patch("msprobe.core.common.log.print", new=mock_print): mock_get_rank.return_value = 0 logger.error_on_rank_0("msg") self.assertIn("[ERROR] msg", mock_print.call_args[0][0]) @@ -99,7 +99,7 @@ class TestLog(TestCase): @patch.object(BaseLogger, "get_rank") def test_warning_on_rank_0(self, mock_get_rank): mock_print = MagicMock() - with patch("atat.core.common.log.print", new=mock_print): + with patch("msprobe.core.common.log.print", new=mock_print): mock_get_rank.return_value = 0 logger.warning_on_rank_0("msg") self.assertIn("[WARNING] msg", mock_print.call_args[0][0]) diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_api_kbk_dump.py similarity index 75% rename from debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_api_kbk_dump.py index 47d60999b16a16d7593559c581354d4674438343..7411018ff08507f0ab867b6394aa1c08b5f26469 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_api_kbk_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_api_kbk_dump.py @@ -19,9 +19,9 @@ import os from unittest import TestCase from unittest.mock import patch -from atat.core.common_config import CommonConfig, BaseConfig -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.dump.api_kbk_dump import ApiKbkDump +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.dump.api_kbk_dump import ApiKbkDump class TestApiKbkDump(TestCase): @@ -42,10 +42,10 @@ class TestApiKbkDump(TestCase): self.assertEqual(dumper.dump_json["common_dump_settings"]["iteration"], "0|2") os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" - with patch("atat.mindspore.dump.api_kbk_dump.make_dump_path_if_not_exists"), \ - patch("atat.mindspore.dump.api_kbk_dump.FileOpen"), \ - patch("atat.mindspore.dump.api_kbk_dump.json.dump"), \ - patch("atat.mindspore.dump.api_kbk_dump.logger.info"): + with patch("msprobe.mindspore.dump.api_kbk_dump.make_dump_path_if_not_exists"), \ + patch("msprobe.mindspore.dump.api_kbk_dump.FileOpen"), \ + patch("msprobe.mindspore.dump.api_kbk_dump.json.dump"), \ + patch("msprobe.mindspore.dump.api_kbk_dump.logger.info"): dumper.handle() self.assertEqual(os.environ.get("GRAPH_OP_RUN"), "1") self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debugger_config.py similarity index 87% rename from debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_debugger_config.py index 3bdf341c3979aa0f7e98cab468dcdcb090215137..5187d3951c0cbf2bdeb5db6f402933c8bf08e94d 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_debugger_config.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debugger_config.py @@ -16,9 +16,9 @@ """ from unittest import TestCase -from atat.core.common.const import Const -from atat.core.common_config import CommonConfig, BaseConfig -from atat.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.core.common.const import Const +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig class TestDebuggerConfig(TestCase): @@ -27,7 +27,7 @@ class TestDebuggerConfig(TestCase): "dump_path": "/absolute_path", "rank": [], "step": [], - "level": "L1" + "level": "L0" } common_config = CommonConfig(json_config) task_config = BaseConfig(json_config) diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py similarity index 89% rename from debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py index f6626f551fec02438e39ed474374654201e6204c..fb88d7bbbf328b0b8a61b11d41808b756881510e 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_dump_tool_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py @@ -16,9 +16,9 @@ """ from unittest import TestCase -from atat.core.common_config import CommonConfig, BaseConfig -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.dump.dump_tool_factory import DumpToolFactory +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory class TestDumpToolFactory(TestCase): diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py similarity index 80% rename from debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py index 6c59521a17d57585170753f4d935b6109b920595..e691a2c7edde2feb1f2c1d60fbba275724bb9092 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_dump.py @@ -19,9 +19,9 @@ import os from unittest import TestCase from unittest.mock import patch -from atat.core.common_config import CommonConfig, BaseConfig -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.dump.kernel_graph_dump import KernelGraphDump +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump class TestKernelGraphDump(TestCase): @@ -45,10 +45,10 @@ class TestKernelGraphDump(TestCase): self.assertEqual(dumper.dump_json["common_dump_settings"]["file_format"], "bin") self.assertEqual(dumper.dump_json["common_dump_settings"]["input_output"], 2) - with patch("atat.mindspore.dump.kernel_graph_dump.make_dump_path_if_not_exists"), \ - patch("atat.mindspore.dump.kernel_graph_dump.FileOpen"), \ - patch("atat.mindspore.dump.kernel_graph_dump.json.dump"), \ - patch("atat.mindspore.dump.kernel_graph_dump.logger.info"): + with patch("msprobe.mindspore.dump.kernel_graph_dump.make_dump_path_if_not_exists"), \ + patch("msprobe.mindspore.dump.kernel_graph_dump.FileOpen"), \ + patch("msprobe.mindspore.dump.kernel_graph_dump.json.dump"), \ + patch("msprobe.mindspore.dump.kernel_graph_dump.logger.info"): os.environ["GRAPH_OP_RUN"] = "1" with self.assertRaises(Exception) as context: diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py similarity index 76% rename from debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py index 101482458dc0a901e0066937ae6df9e9a23fe4fc..a93fab021ab59beff9016895a1748eb942274e30 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_kernel_graph_overflow_check.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py @@ -19,9 +19,9 @@ import os from unittest import TestCase from unittest.mock import patch -from atat.core.common_config import CommonConfig, BaseConfig -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.overflow_check.kernel_graph_overflow_check import KernelGraphOverflowCheck class TestKernelGraphOverflowCheck(TestCase): @@ -43,10 +43,10 @@ class TestKernelGraphOverflowCheck(TestCase): self.assertEqual(checker.dump_json["common_dump_settings"]["op_debug_mode"], 2) os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" - with patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.make_dump_path_if_not_exists"), \ - patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.FileOpen"), \ - patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.json.dump"), \ - patch("atat.mindspore.overflow_check.kernel_graph_overflow_check.logger.info"): + with patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.make_dump_path_if_not_exists"), \ + patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.FileOpen"), \ + patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.json.dump"), \ + patch("msprobe.mindspore.overflow_check.kernel_graph_overflow_check.logger.info"): os.environ["GRAPH_OP_RUN"] = "1" with self.assertRaises(Exception) as context: diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py similarity index 83% rename from debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py index 3dc3670128c1c6b74b0f67fe43b008df340dc858..30212d95e621bea516b888e7e61d042990a2c93a 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py @@ -17,9 +17,9 @@ from unittest import TestCase from unittest.mock import patch, mock_open -from atat.core.common.const import Const -from atat.mindspore.ms_config import (parse_json_config, parse_task_config, - TensorConfig, StatisticsConfig, OverflowCheck) +from msprobe.core.common.const import Const +from msprobe.mindspore.ms_config import (parse_json_config, parse_task_config, + TensorConfig, StatisticsConfig, OverflowCheckConfig) class TestMsConfig(TestCase): @@ -37,8 +37,8 @@ class TestMsConfig(TestCase): "summary_mode": "statistics" } } - with patch("atat.mindspore.ms_config.FileOpen", mock_open(read_data='')), \ - patch("atat.mindspore.ms_config.json.load", return_value=mock_json_data): + with patch("msprobe.mindspore.ms_config.FileOpen", mock_open(read_data='')), \ + patch("msprobe.mindspore.ms_config.json.load", return_value=mock_json_data): common_config, task_config = parse_json_config("./config.json") self.assertEqual(common_config.task, Const.STATISTICS) self.assertEqual(task_config.data_mode, ["all"]) @@ -62,7 +62,7 @@ class TestMsConfig(TestCase): self.assertTrue(isinstance(task_config, StatisticsConfig)) task_config = parse_task_config("overflow_check", mock_json_config) - self.assertTrue(isinstance(task_config, OverflowCheck)) + self.assertTrue(isinstance(task_config, OverflowCheckConfig)) with self.assertRaises(Exception) as context: parse_task_config("free_benchmark", mock_json_config) diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py similarity index 88% rename from debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py index 497fe1376abcff0607d6afd3f2b03d94f963bcd7..47da051d4fdd1d9b65ef8c6092a7b05a3e2263b6 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_overflow_check_tool_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py @@ -16,9 +16,9 @@ """ from unittest import TestCase -from atat.core.common_config import CommonConfig, BaseConfig -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory class TestOverflowCheckToolFactory(TestCase): diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_precision_debugger.py similarity index 79% rename from debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_precision_debugger.py index 834a58e41a426d975ac97b0f757db5a0432a297f..425ed3040dcc829927a8f4cbb25024f1b567a48f 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_precision_debugger.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_precision_debugger.py @@ -17,9 +17,9 @@ from unittest import TestCase from unittest.mock import patch -from atat.core.common_config import CommonConfig, BaseConfig -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.debugger.precision_debugger import PrecisionDebugger +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger class TestPrecisionDebugger(TestCase): @@ -35,16 +35,16 @@ class TestPrecisionDebugger(TestCase): "dump_path": "/absolute_path", "rank": [], "step": [], - "level": "L1" + "level": "L0" } common_config = CommonConfig(json_config) task_config = BaseConfig(json_config) handler = Handler() - with patch("atat.mindspore.debugger.precision_debugger.parse_json_config", + with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", return_value=[common_config, task_config]), \ - patch("atat.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler): + patch("msprobe.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler): debugger = PrecisionDebugger() debugger.start() self.assertTrue(isinstance(debugger.config, DebuggerConfig)) diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py similarity index 82% rename from debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py index 02cd9934cb1635b9cec68bcd4b54dd2753ea23e9..41be7b1db6c7d723aaeec1607f564ac3d772b404 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py @@ -17,10 +17,10 @@ from unittest import TestCase from unittest.mock import patch -from atat.core.common_config import CommonConfig, BaseConfig -from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.mindspore.dump.kernel_graph_dump import KernelGraphDump -from atat.mindspore.task_handler_factory import TaskHandlerFactory +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump +from msprobe.mindspore.task_handler_factory import TaskHandlerFactory class TestTaskHandlerFactory(TestCase): @@ -47,7 +47,7 @@ class TestTaskHandlerFactory(TestCase): handler = TaskHandlerFactory.create(config) self.assertTrue(isinstance(handler, KernelGraphDump)) - with patch("atat.mindspore.task_handler_factory.TaskHandlerFactory.tasks", new=tasks): + with patch("msprobe.mindspore.task_handler_factory.TaskHandlerFactory.tasks", new=tasks): with self.assertRaises(Exception) as context: TaskHandlerFactory.create(config) self.assertEqual(str(context.exception), "Can not find task handler") diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/advisor/test_advisor.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/advisor/test_advisor.py similarity index 85% rename from debug/accuracy_tools/atat/test/pytorch_ut/advisor/test_advisor.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/advisor/test_advisor.py index 78e5b489e7ad14f2965b813d733a30c8849b8a71..176b80068f70e60a06a6eed77b23b35e8b48a50d 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/advisor/test_advisor.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/advisor/test_advisor.py @@ -2,12 +2,13 @@ import difflib import os import shutil import unittest +import logging from unittest.mock import patch import pandas -from atat.pytorch.advisor.advisor import Advisor -from atat.pytorch.advisor.advisor_const import AdvisorConst +from msprobe.pytorch.advisor.advisor import Advisor +from msprobe.pytorch.advisor.advisor_const import AdvisorConst class TestAdvisor(unittest.TestCase): @@ -70,11 +71,11 @@ class TestAdvisor(unittest.TestCase): output_content = out_file.read().splitlines() result = list(difflib.unified_diff(standard_content, output_content, n=0)) if result: - print('\n\n-------------------------------------------------------------------------', flush=True) - print(f'[ERROR] {output_file.replace(self.output_path, "")} advisor summary are inconsistent.', - flush=True) - print('\n'.join(result), flush=True) - print('-------------------------------------------------------------------------', flush=True) + logging.basicConfig(level=logging.INFO) + logging.info('\n\n-------------------------------------------------------------------------') + logging.error(f'[ERROR] {output_file.replace(self.output_path, "")} advisor summary are inconsistent.') + logging.error('\n'.join(result)) + logging.info('\n\n-------------------------------------------------------------------------') self.has_error = True diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py similarity index 96% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py index 16d0c0bc12738bb7ce129224cf124761503c31fd..56d100f0a1b570c0c0a1753db3c79aacea7b76ac 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py @@ -1,12 +1,12 @@ import unittest from unittest.mock import patch -from atat.pytorch.api_accuracy_checker.common.utils import * +from msprobe.pytorch.api_accuracy_checker.common.utils import * class TestUtils(unittest.TestCase): - @patch('atat.pytorch.api_accuracy_checker.common.utils.get_file_content_bytes') + @patch('msprobe.pytorch.api_accuracy_checker.common.utils.get_file_content_bytes') def test_get_json_contents_should_raise_exception(self, mock_get_file_content_bytes): mock_get_file_content_bytes.return_value = 'not a dict' with self.assertRaises(CompareException) as ce: diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py similarity index 92% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_config.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py index 066e74aa518dd7958a511bb47be92bce7ce5ac0b..35fc6164763e685d09e737e7f85bec33623ec111 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/common/test_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py @@ -2,7 +2,7 @@ import unittest import os from unittest.mock import patch -from atat.pytorch.api_accuracy_checker.common.config import Config +from msprobe.pytorch.api_accuracy_checker.common.config import Config class TestConfig(unittest.TestCase): @@ -35,5 +35,5 @@ class TestConfig(unittest.TestCase): validate_white_list = ['conv1d', 'max_pool1d', 'dropout', '__add__'] self.assertEqual(self.cfg.validate('white_list', validate_white_list), validate_white_list) - with self.assertRaises(ValueError): + with self.assertRaises(Exception): self.cfg.validate('white_list', ['invalid_api1', 'invalid_api2']) diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py similarity index 98% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py index 9604e7a681c869c41cdfa70b7b1b551ceed9604e..35a8b9f1fa52f689905986ca477c8a7077a084da 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py @@ -2,7 +2,7 @@ import unittest import numpy as np -from atat.pytorch.api_accuracy_checker.compare import algorithm as alg +from msprobe.pytorch.api_accuracy_checker.compare import algorithm as alg class TestAlgorithmMethods(unittest.TestCase): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py similarity index 94% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py index 7717d826577cbc3ade87c6188e1049c7455df905..540460d0896532bbe242c3eb30b3ee945bae9571 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py @@ -2,14 +2,14 @@ import unittest import pandas as pd -from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import ( +from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import ( CompareConfig, BenchmarkStandard, check_csv_columns, check_error_rate, get_api_checker_result, ) -from atat.core.common.const import CompareConst +from msprobe.core.common.const import CompareConst class TestApiPrecisionCompare(unittest.TestCase): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py similarity index 96% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py index 2c97471c7ab2aec4a6db25d6d8dca3cd07001cda..e1e6d51de292cf4d8b617ab73db67ff4920bfac3 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py @@ -7,9 +7,9 @@ import unittest import numpy as np import torch.nn.functional -from atat.pytorch.api_accuracy_checker.compare.compare import Comparator -from atat.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn -from atat.pytorch.api_accuracy_checker.run_ut.run_ut import UtDataInfo +from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator +from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import UtDataInfo current_time = time.strftime("%Y%m%d%H%M%S") RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py similarity index 68% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py index ee25a25e74d18dd0cc5436747767aa2acbbab05e..782321868a8cbcae9ffed3b215ca068968b1c0ae 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py @@ -1,6 +1,6 @@ import unittest -from atat.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn +from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn class TestCompareColumns(unittest.TestCase): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py similarity index 88% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py index 93f3c2c73e14130f479198a32a6b4f2f5c64745d..ac9c974ea3ecf6a835ce448d754582d435548ed8 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py @@ -2,8 +2,8 @@ import unittest import numpy as np -from atat.pytorch.api_accuracy_checker.common.utils import CompareException -from atat.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, convert_str_to_float +from msprobe.pytorch.api_accuracy_checker.common.utils import CompareException +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, convert_str_to_float class TestCompareUtils(unittest.TestCase): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json similarity index 100% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json similarity index 100% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py similarity index 96% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py index f47c71c984f7e03449c8b1bcee11d4e81e64f842..f664dad197f6bbaaed3d574f657552377b176dec 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py @@ -3,8 +3,8 @@ import os import unittest import copy -from atat.pytorch.api_accuracy_checker.run_ut.data_generate import * -from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents +from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import * +from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents base_dir = os.path.dirname(os.path.realpath(__file__)) forward_file = os.path.join(base_dir, "forward.json") diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py similarity index 83% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py index 6a9071f15ea982ba8b515c4e95683be586b39aef..27126cdddda215ea521b5090782cc1c85f07e5f0 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py @@ -5,7 +5,7 @@ import logging from unittest.mock import patch, mock_open, MagicMock import json import signal -from atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut import split_json_file, signal_handler, run_parallel_ut, \ +from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import split_json_file, signal_handler, run_parallel_ut, \ prepare_config, main, ParallelUTConfig @@ -20,7 +20,7 @@ class TestMultiRunUT(unittest.TestCase): {'key3': 'TRUE', 'key4': 'TRUE'} ] - @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileOpen') + @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileOpen') def test_split_json_file(self, mock_FileOpen): mock_FileOpen.return_value.__enter__.return_value = mock_open(read_data=self.test_json_content).return_value num_splits = 2 @@ -48,7 +48,7 @@ class TestMultiRunUT(unittest.TestCase): device_id=[0, 1], result_csv_path='result.csv', total_items=2, - real_data_path=None + config_path=None ) mock_file.side_effect = [ @@ -63,10 +63,10 @@ class TestMultiRunUT(unittest.TestCase): @patch('os.remove') @patch('os.path.realpath', side_effect=lambda x: x) - @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_link') - @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_file_suffix') - @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileChecker') - @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.split_json_file', + @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_link') + @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_file_suffix') + @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileChecker') + @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.split_json_file', return_value=(['forward_split1.json', 'forward_split2.json'], 2)) def test_prepare_config(self, mock_split_json_file, mock_FileChecker, mock_check_file_suffix, mock_check_link, mock_realpath, mock_remove): @@ -81,7 +81,7 @@ class TestMultiRunUT(unittest.TestCase): args.jit_compile = False args.device_id = [0, 1] args.result_csv_path = None - args.real_data_path = None + args.config_path = None config = prepare_config(args) @@ -93,8 +93,8 @@ class TestMultiRunUT(unittest.TestCase): @patch('argparse.ArgumentParser.parse_args') - @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.prepare_config') - @patch('atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut.run_parallel_ut') + @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.prepare_config') + @patch('msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut.run_parallel_ut') def test_main(self, mock_run_parallel_ut, mock_prepare_config, mock_parse_args): main() mock_parse_args.assert_called() diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py similarity index 95% rename from debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py index 97dccd2b58f7005cdd8c3907ff150d9df7f9ff3d..bc643794ab692ff5d21bcf412450f845d01f662d 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py @@ -4,8 +4,8 @@ import copy import unittest import torch from unittest.mock import patch, DEFAULT -from atat.pytorch.api_accuracy_checker.run_ut.run_ut import * -from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import * +from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents base_dir = os.path.dirname(os.path.realpath(__file__)) forward_file = os.path.join(base_dir, "forward.json") diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_acc_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..288e259c0aae104a62054af3813b7831ec7722f7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_acc_compare.py @@ -0,0 +1,267 @@ +# coding=utf-8 +import unittest +import pandas as pd +from msprobe.pytorch.compare import acc_compare as compare + +npu_dict = {'op_name': ['Functional_conv2d_0_forward_input.0', 'Functional_conv2d_0_forward_input.1', + 'Functional_conv2d_0_forward_input.2', 'Functional_conv2d_0_forward_output'], + 'input_struct': [('torch.float32', [1, 1, 28, 28]), ('torch.float32', [16, 1, 5, 5]), + ('torch.float32', [16])], + 'output_struct': [('torch.float32', [1, 16, 28, 28])], + 'summary': [[3.029174327850342, -2.926689624786377, -0.06619918346405029], + [0.19919930398464203, -0.19974489510059357, 0.006269412115216255], + [0.19734230637550354, -0.18177609145641327, 0.007903944700956345], + [2.1166646480560303, -2.190781354904175, -0.003579073818400502]], 'stack_info': []} + +bench_dict = {'op_name': ['Functional_conv2d_0_forward_input.0', 'Functional_conv2d_0_forward_input.1', + 'Functional_conv2d_0_forward_input.2', 'Functional_conv2d_0_forward_output'], + 'input_struct': [('torch.float32', [1, 1, 28, 28]), ('torch.float32', [16, 1, 5, 5]), + ('torch.float32', [16])], + 'output_struct': [('torch.float32', [1, 16, 28, 28])], + 'summary': [[3.029174327850342, -2.926689624786377, -0.06619918346405029], + [0.19919930398464203, -0.19974489510059357, 0.006269412115216255], + [0.19734230637550354, -0.18177609145641327, 0.007903944700956345], + [2.1166646480560303, -2.190781354904175, -0.003579073818400502]], 'stack_info': []} + +tensor_list = [ + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], 'Max': 0.33033010363578796, + 'Min': -0.331031858921051,'Mean': -0.030964046716690063, 'Norm': 2.2533628940582275, 'requires_grad': True, + 'full_op_name': 'Tensor.add_.0.forward_input.0'}, + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], + 'Max': 0.003992878366261721, 'Min': -0.008102823048830032, 'Mean': -0.0002002553956117481, + 'Norm': 0.02844562754034996, 'requires_grad': False, 'full_op_name': 'Tensor.add_.0.forward_input.1'}, + {'full_op_name': 'Tensor.add_.0.forward_input.alpha.0', 'dtype': "", "shape": '[]', 'md5': None, + 'Max': -0.1, 'Min': -0.1, 'Mean': -0.1, 'Norm': -0.1, 'data_name': '-1'}, + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], + 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, + 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_.0.forward_output.0'} +] + +result_op_dict = {'op_name': ['Tensor.add_.0.forward_input.0', 'Tensor.add_.0.forward_input.1', + 'Tensor.add_.0.forward_input.alpha.0', 'Tensor.add_.0.forward_output.0'], + 'input_struct': [('torch.float32', [16, 1, 3, 3]), ('torch.float32', [16, 1, 3, 3]), + ("", '[]')], + 'output_struct': [('torch.float32', [16, 1, 3, 3])], + 'summary': [[0.33033010363578796, -0.331031858921051, -0.030964046716690063, 2.2533628940582275], + [0.003992878366261721, -0.008102823048830032, -0.0002002553956117481, 0.02844562754034996], + [-0.1, -0.1, -0.1, -0.1], + [0.33033010363578796, -0.331031858921051, -0.030964046716690063, 2.2533628940582275]], + 'stack_info': []} + +o_result = [ + ['Functional_conv2d_0_forward_input.0', 'Functional_conv2d_0_forward_input.0', 'torch.float32', 'torch.float32', + [1, 1, 28, 28], [1, 1, 28, 28], 0.0, 0.0, 0.0, ' ', '0.0%', '0.0%', '0.0%', ' ', 3.029174327850342, -2.926689624786377, + -0.06619918346405029, 3.029174327850342, -2.926689624786377, -0.06619918346405029, '', '', 'None'], + ['Functional_conv2d_0_forward_input.1', 'Functional_conv2d_0_forward_input.1', 'torch.float32', 'torch.float32', + [16, 1, 5, 5], [16, 1, 5, 5], 0.0, 0.0, 0.0, ' ', '0.0%', '0.0%', '0.0%', ' ', 0.19919930398464203, -0.19974489510059357, + 0.006269412115216255, 0.19919930398464203, -0.19974489510059357, 0.006269412115216255, '', '', 'None'], + ['Functional_conv2d_0_forward_input.2', 'Functional_conv2d_0_forward_input.2', 'torch.float32', 'torch.float32', + [16], [16], 0.0, 0.0, 0.0, ' ', '0.0%', '0.0%', '0.0%', ' ', 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, + 0.19734230637550354, -0.18177609145641327, 0.007903944700956345, '', '', 'None'], + ['Functional_conv2d_0_forward_output', 'Functional_conv2d_0_forward_output', 'torch.float32', 'torch.float32', + [1, 16, 28, 28], [1, 16, 28, 28], 0.0, 0.0, 0.0, ' ', '0.0%', '0.0%', '0.0%', ' ', 2.1166646480560303, -2.190781354904175, + -0.003579073818400502, 2.1166646480560303, -2.190781354904175, -0.003579073818400502, '', '', 'None']] + +npu_dict_aten = {'op_name': ['Aten__native_batch_norm_legit_functional.default_0_forward_input.0', + 'Aten__native_batch_norm_legit_functional.default_0_forward_input.1', + 'Aten__native_batch_norm_legit_functional.default_0_forward_input.2', + 'Aten__native_batch_norm_legit_functional.default_0_forward_input.3', + 'Aten__native_batch_norm_legit_functional.default_0_forward_input.4', + 'Aten__native_batch_norm_legit_functional.default_0_forward_output.0', + 'Aten__native_batch_norm_legit_functional.default_0_forward_output.1', + 'Aten__native_batch_norm_legit_functional.default_0_forward_output.2', + 'Aten__native_batch_norm_legit_functional.default_0_forward_output.3', + 'Aten__native_batch_norm_legit_functional.default_0_forward_output.4'], + 'input_struct': [('torch.float16', [256, 256, 14, 14]), ('torch.float32', [256]), + ('torch.float32', [256]), ('torch.float32', [256]), ('torch.float32', [256])], + 'output_struct': [('torch.float16', [256, 256, 14, 14]), ('torch.float32', [256]), + ('torch.float32', [256]), ('torch.float32', [256]), ('torch.float32', [256])], + 'summary': [[139.625, -127.5625, -0.0103607177734375], + [2.5276029109954834, -2.1788690090179443, -0.0008259844034910202], + [2.472219944000244, -2.845968723297119, -0.008756577968597412], + [2.763145923614502, -3.398397922515869, -0.052132632583379745], + [2.673110008239746, -3.149275064468384, 0.01613386906683445], + [13.5546875, -10.640625, -0.008758544921875], + [0.30550330877304077, -0.24485322833061218, -0.010361209511756897], + [623.9192504882812, 432.96826171875, 520.2276611328125], + [2.4797861576080322, -3.055997371673584, -0.04795549064874649], + [61.7945556640625, 42.59713363647461, 52.03831481933594]]} + +bench_dict_functional = { + 'op_name': ['Functional_batch_norm_0_forward_input.0', 'Functional_batch_norm_0_forward_input.1', + 'Functional_batch_norm_0_forward_input.2', 'Functional_batch_norm_0_forward_input.3', + 'Functional_batch_norm_0_forward_input.4', 'Functional_batch_norm_0_forward_output'], + 'input_struct': [('torch.float32', [256, 256, 14, 14]), ('torch.float32', [256]), ('torch.float32', [256]), + ('torch.float32', [256]), ('torch.float32', [256])], + 'output_struct': [('torch.float32', [256, 256, 14, 14])], + 'summary': [[3.061628818511963, -3.22507381439209, 3.634914173744619e-05], + [0.0005779837374575436, -0.0006301702815108001, 3.634906533989124e-06], + [0.9338104128837585, 0.9277191162109375, 0.930335283279419], + [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], + [5.397906303405762, -5.796811580657959, 2.5283952709287405e-10]] +} + +aten_result = [ + ['Aten__native_batch_norm_legit_functional.default_0_forward_input.0', 'Functional_batch_norm_0_forward_input.0', + 'torch.float16', 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 136.56337118148804, -124.33742618560791, + -0.010397066915174946, ' ', '4460.480981749501%', '3855.335826136584%', '28603.33536971545%', ' ', 139.625, + -127.5625, -0.0103607177734375, 3.061628818511963, -3.22507381439209, 3.634914173744619e-05, 'Warning', + 'Need double check api accuracy.', 'None'], + ['Aten__native_batch_norm_legit_functional.default_0_forward_input.1', 'Functional_batch_norm_0_forward_input.1', + 'torch.float32', 'torch.float32', [256], [256], 2.527024927258026, -2.1782388387364335, -0.0008296193100250093, + ' ', '437213.84590749856%', '345658.76916858414%', '22823.676544842117%', ' ', 2.5276029109954834, + -2.1788690090179443, -0.0008259844034910202, 0.0005779837374575436, -0.0006301702815108001, 3.634906533989124e-06, + 'Warning', 'Need double check api accuracy.', 'None'], + ['Aten__native_batch_norm_legit_functional.default_0_forward_input.2', 'Functional_batch_norm_0_forward_input.2', + 'torch.float32', 'torch.float32', [256], [256], 1.5384095311164856, -3.7736878395080566, -0.9390918612480164, ' ', + '164.74538192025793%', '406.7705163736246%', '100.94122819224167%', ' ', 2.472219944000244, -2.845968723297119, + -0.008756577968597412, 0.9338104128837585, 0.9277191162109375, 0.930335283279419, 'Warning', + 'Need double check api accuracy.', 'None'], + ['Aten__native_batch_norm_legit_functional.default_0_forward_input.3', 'Functional_batch_norm_0_forward_input.3', + 'torch.float32', 'torch.float32', [256], [256], 1.763145923614502, -4.398397922515869, -1.0521326325833797, ' ', + '176.3145923614502%', '439.8397922515869%', '105.21326325833797%', ' ', 2.763145923614502, -3.398397922515869, + -0.052132632583379745, 1.0, 1.0, 1.0, 'Warning', 'Need double check api accuracy.', 'None'], + ['Aten__native_batch_norm_legit_functional.default_0_forward_input.4', 'Functional_batch_norm_0_forward_input.4', + 'torch.float32', 'torch.float32', [256], [256], 2.673110008239746, -3.149275064468384, 0.01613386906683445, ' ', + 'N/A', 'N/A', 'N/A', ' ', 2.673110008239746, -3.149275064468384, 0.01613386906683445, 0.0, 0.0, 0.0, 'Warning', + 'Need double check api accuracy.', 'None'], + ['Aten__native_batch_norm_legit_functional.default_0_forward_output.0', 'Functional_batch_norm_0_forward_output', + 'torch.float16', 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 8.156781196594238, -4.843813419342041, + -0.008758545174714527, ' ', '151.11009228611078%', '83.55995967687207%', '3464072756.115108%', ' ', 13.5546875, + -10.640625, -0.008758544921875, 5.397906303405762, -5.796811580657959, 2.5283952709287405e-10, 'Warning', + 'Need double check api accuracy.', 'None'], + ['Aten__native_batch_norm_legit_functional.default_0_forward_output.1', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', + ' ', ' ', ' ', ' ', ' ', 0.30550330877304077, -0.24485322833061218, -0.010361209511756897, 'Nan', 'Nan', 'Nan', + 'Yes', '', 'None'], + ['Aten__native_batch_norm_legit_functional.default_0_forward_output.2', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', + ' ', ' ', ' ', ' ', ' ', 623.9192504882812, 432.96826171875, 520.2276611328125, 'Nan', 'Nan', 'Nan', + 'Yes', '', 'None'], + ['Aten__native_batch_norm_legit_functional.default_0_forward_output.3', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', + ' ', ' ', ' ', ' ', ' ', 2.4797861576080322, -3.055997371673584, -0.04795549064874649, 'Nan', 'Nan', 'Nan', + 'Yes', '', 'None'], + ['Aten__native_batch_norm_legit_functional.default_0_forward_output.4', 'Nan', 'torch.float32', 'Nan', [256], 'Nan', + ' ', ' ', ' ', ' ', ' ', 61.7945556640625, 42.59713363647461, 52.03831481933594, 'Nan', 'Nan', 'Nan', + 'Yes', '', 'None']] + +highlight_dict = {'red_rows': [], 'yellow_rows': []} + +num_0, num_1, num_2, num_3 = 0, 1, 2, 3 +summary_line_input = ['Functional_batch_norm_0_forward_input.0', 'Functional_batch_norm_0_forward_input.0', + 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.01, 0, 0, 0, 1, 1, 1, 1, 1.01, 1, 1, 1, + 'Yes', ''] +summary_line_1 = ['Functional_batch_norm_0_forward_output.0', 'Functional_batch_norm_0_forward_output.0', + 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 10, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, + 'Warning', ''] +summary_line_2 = ['Functional_batch_norm_0_forward_output.1', 'Functional_batch_norm_0_forward_output.1', + 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.02, 0, 0, 0, 0.12, 0, 1, 1, 0.1, 1, 1, 1, + 'Warning', ''] +summary_line_3 = ['Functional_batch_norm_0_forward_output.2', 'Functional_batch_norm_0_forward_output.2', + 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0, 0, 0, 0, 2, 0, 1, 1, 1, 1, 1, 1, + 'Warning', ''] +line_input = ['Functional_batch_norm_0_forward_input.0', 'Functional_batch_norm_0_forward_input.0', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 1, 1, 1, 0.95, 1, 1, 1, 1, 1, 1.01, 1, 1, 1, + 'Yes', ''] +line_1 = ['Functional_batch_norm_0_forward_output.0', 'Functional_batch_norm_0_forward_output.0', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 1, 1, 0.59, 1, 'nan', 0, 1, 1, 19, 1, 1, 1, + 'Warning', ''] +line_2 = ['Functional_batch_norm_0_forward_output.1', 'Functional_batch_norm_0_forward_output.1', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.9, 1, 1, 0.8, 1, 0, 0.12, 0, 1, 1, 0.1, 1, 1, 1, + 'Warning', ''] +line_3 = ['Functional_batch_norm_0_forward_output.2', 'Functional_batch_norm_0_forward_output.2', 'torch.float16', + 'torch.float32', [256, 256, 14, 14], [256, 256, 14, 14], 0.8, 1.1e+10, 1, 0.85, 1, 9, 0.12, 0, 1, 1, 0.1, 1, + 1, 1, 'Warning', ''] + +op_data = { + 'input_args': [{'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], + 'Max': 0.33033010363578796, 'Min': -0.331031858921051,'Mean': -0.030964046716690063, + 'Norm': 2.2533628940582275, 'requires_grad': True}, + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], + 'Max': 0.003992878366261721, 'Min': -0.008102823048830032, 'Mean': -0.0002002553956117481, + 'Norm': 0.02844562754034996, 'requires_grad': False}], + 'input_kwargs': {'alpha': {'type': 'float', 'value': -0.1}}, + 'output': [{'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], + 'Max': 0.33033010363578796, 'Min': -0.331031858921051,'Mean': -0.030964046716690063, + 'Norm': 2.2533628940582275, 'requires_grad': True}]} + +op_name = "Tensor.add_0.0.forward" + +op_result = [ + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], + 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, + 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_0.0.forward_input.0'}, + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], + 'Max': 0.003992878366261721, 'Min': -0.008102823048830032, 'Mean': -0.0002002553956117481, + 'Norm': 0.02844562754034996, 'requires_grad': False, 'full_op_name': 'Tensor.add_0.0.forward_input.1'}, + {'full_op_name': 'Tensor.add_0.0.forward_input.alpha.0', 'dtype': "", 'shape': '[]', 'md5': None, + 'Max': -0.1, 'Min': -0.1, 'Mean': -0.1, 'Norm': -0.1, 'data_name': '-1'}, + {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [16, 1, 3, 3], + 'Max': 0.33033010363578796, 'Min': -0.331031858921051, 'Mean': -0.030964046716690063, + 'Norm': 2.2533628940582275, 'requires_grad': True, 'full_op_name': 'Tensor.add_0.0.forward_output.0'}] + + +class TestUtilsMethods(unittest.TestCase): + + def test_check_graph_mode(self): + op1 = "Aten" + op2 = "torch" + self.assertTrue(compare.check_graph_mode(op1, op2)) + self.assertTrue(compare.check_graph_mode(op2, op1)) + self.assertFalse(compare.check_graph_mode(op1, op1)) + self.assertFalse(compare.check_graph_mode(op2, op2)) + + def test_check_op(self): + fuzzy_match = False + result = compare.check_op(npu_dict, bench_dict, fuzzy_match) + self.assertEqual(result, True) + + def test_merge_tensor(self): + op_dict = compare.merge_tensor(tensor_list, True, False) + self.assertEqual(op_dict, result_op_dict) + + def test_read_op(self): + result = compare.read_op(op_data, op_name) + self.assertEqual(result, op_result) + + def test_match_op(self): + fuzzy_match = False + a, b = compare.match_op([npu_dict], [bench_dict], fuzzy_match) + self.assertEqual(a, 0) + self.assertEqual(b, 0) + + def test_get_accuracy(self): + result = [] + compare.get_accuracy(result, npu_dict, bench_dict, highlight_dict) + self.assertEqual(result, o_result) + + def test_get_accuracy_graph_mode(self): + result = [] + compare.get_accuracy(result, npu_dict_aten, bench_dict_functional, highlight_dict) + self.assertEqual(result, aten_result) + + def test_find_error_rows(self): + summary_result = [summary_line_input, summary_line_1, summary_line_2, summary_line_3] + highlight_dict = {'red_rows': [], 'yellow_rows': []} + compare.find_error_rows(summary_result, 0, 1, highlight_dict, summary_compare=True) + self.assertEqual(highlight_dict, {'red_rows': [], 'yellow_rows': []}) + + def test_find_compare_result_error_rows(self): + result = [line_input, line_1, line_2, line_3] + result_df = pd.DataFrame(result) + highlight_dict = {'red_rows': [], 'yellow_rows': []} + compare.find_compare_result_error_rows(result_df, highlight_dict, False, False) + self.assertEqual(highlight_dict, {'red_rows': [num_1, num_3], 'yellow_rows': [num_2]}) + + def test_rename_api(self): + test_name_1 = "Distributed.broadcast.0.forward.input.0" + expect_name_1 = "Distributed.broadcast.input.0" + actual_name_1 = compare.rename_api(test_name_1, "forward") + self.assertEqual(actual_name_1, expect_name_1) + + test_name_2 = "Torch.sum.0.backward.output.0" + expect_name_2 = "Torch.sum.output.0" + actual_name_2 = compare.rename_api(test_name_2, "backward") + self.assertEqual(actual_name_2, expect_name_2) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py new file mode 100644 index 0000000000000000000000000000000000000000..ac28e994e9c8e77f8ae675fec3322eaf64a64321 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_match.py @@ -0,0 +1,20 @@ +# coding=utf-8 +import unittest +from msprobe.pytorch.compare import match + + +class TestMatch(unittest.TestCase): + def test_graph_mapping(self): + op1 = "Aten_convolution_1_forward_0.input.0" + op2 = "Torch_conv2d_0_forward_0.input.0" + op3 = "Torch_batch_norm_0_forward_0.input.0" + op4 = "Aten_convolution.default_1_forward_0.input.0" + op5 = "Aten_foo_1_forward_0.input.0" + self.assertTrue(match.graph_mapping.match(op1, op2)) + self.assertTrue(match.graph_mapping.match(op2, op1)) + self.assertTrue(match.graph_mapping.match(op4, op2)) + self.assertTrue(match.graph_mapping.match(op2, op4)) + self.assertFalse(match.graph_mapping.match(op1, op3)) + self.assertFalse(match.graph_mapping.match(op3, op1)) + self.assertFalse(match.graph_mapping.match(op5, op2)) + self.assertFalse(match.graph_mapping.match(op2, op5)) diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py similarity index 94% rename from debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py index 828d646c52f363e758a10939dbfcd98d942eb9f2..ad9eb5cd0ed5b2eaaff9745c8af9ced8dc1ab883 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py @@ -1,10 +1,10 @@ from unittest import TestCase import torch -from atat.core.common.const import Const -from atat.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode -from atat.pytorch.free_benchmark.common.params import data_pre_deal -from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory +from msprobe.core.common.const import Const +from msprobe.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode +from msprobe.pytorch.free_benchmark.common.params import data_pre_deal +from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory class TestPerturbedLayer(TestCase): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py similarity index 91% rename from debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py index d46e26e09488d481ac9e96a8b4002dd3b62446bd..399efeb42d7cd7e7e34dd472cd8a9d82c26a5b5e 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py @@ -2,17 +2,17 @@ from abc import ABC from unittest import TestCase import torch -from atat.core.common.const import Const -from atat.pytorch.free_benchmark.common.constant import PreheatConfig, ThresholdConfig -from atat.pytorch.free_benchmark.common.counter import preheat_counter -from atat.pytorch.free_benchmark.common.enums import ( +from msprobe.core.common.const import Const +from msprobe.pytorch.free_benchmark.common.constant import PreheatConfig, ThresholdConfig +from msprobe.pytorch.free_benchmark.common.counter import preheat_counter +from msprobe.pytorch.free_benchmark.common.enums import ( DeviceType, FuzzLevel, HandlerType, PerturbationMode, ) -from atat.pytorch.free_benchmark.common.params import DataParams, make_handler_params -from atat.pytorch.free_benchmark.result_handlers.handler_factory import ( +from msprobe.pytorch.free_benchmark.common.params import DataParams, make_handler_params +from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import ( FuzzHandlerFactory, ) diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py similarity index 92% rename from debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py index d326e993c07d66a1baf5ae785ed4b519624bb982..4498a2af7054edd89aa6fae6a057a489216794b6 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py @@ -4,10 +4,10 @@ from unittest import TestCase import torch import torch.nn as nn -from atat.core.common.const import Const -from atat.pytorch.free_benchmark import FreeBenchmarkCheck -from atat.pytorch.free_benchmark.common.constant import CommonField, PreheatConfig -from atat.pytorch.free_benchmark.common.enums import ( +from msprobe.core.common.const import Const +from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck +from msprobe.pytorch.free_benchmark.common.constant import CommonField, PreheatConfig +from msprobe.pytorch.free_benchmark.common.enums import ( DeviceType, FuzzLevel, HandlerType, diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_dump_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_dump_module.py new file mode 100644 index 0000000000000000000000000000000000000000..d67adf2f91292391ff01d450bb5647524f6fc9c4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/functional/test_dump_module.py @@ -0,0 +1,15 @@ +import unittest + +import torch.nn as nn +from msprobe.pytorch import PrecisionDebugger +from msprobe.pytorch.functional.dump_module import module_dump, module_count + + +class TestDumpModule(unittest.TestCase): + def setUp(self): + self.module = nn.Linear(in_features=8, out_features=4) + + def test_module_dump(self): + PrecisionDebugger(dump_path="./dump") + module_dump(self.module, "TestModule") + self.assertTrue("TestModule" in module_count) diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_api_registry.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py similarity index 91% rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_api_registry.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py index c80e5dbed456fd7df9c574a14f96e16b886e8a3e..837ad23df76be2a012a7408dab4879847937f229 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_api_registry.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_api_registry.py @@ -1,5 +1,5 @@ import unittest -from atat.pytorch.hook_module.api_registry import ApiRegistry, torch_version_above_2, is_gpu +from msprobe.pytorch.hook_module.api_registry import ApiRegistry, torch_version_above_2, is_gpu class TestApiRegistry(unittest.TestCase): @@ -43,7 +43,7 @@ class TestApiRegistry(unittest.TestCase): import torch import torch.distributed as dist #import torch_npu #门禁没有安装torch_npu - from atat.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2 + from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2 @@ -79,7 +79,7 @@ class TestApiRegistry(unittest.TestCase): import torch import torch.distributed as dist #import torch_npu #门禁没有安装torch_npu - from atat.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2 + from msprobe.pytorch.hook_module.api_registry import torch_without_guard_version, npu_distributed_api, is_gpu, torch_version_above_2 diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_hook_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py similarity index 94% rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_hook_module.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py index 646f64152267702d10eb0ff9aa2bce4e497a9f35..50783e5d736c024b03f20008ad6b72882eddcd87 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_hook_module.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import patch, Mock -from atat.pytorch.hook_module.hook_module import HOOKModule +from msprobe.pytorch.hook_module.hook_module import HOOKModule class TestHookModule(unittest.TestCase): def test_call_1(self): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_aten.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py similarity index 96% rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_aten.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py index 92aee790ddd4d28e652cc063209028a1e4e5b3d1..4940b07cb0d8e9d2283db3daebc910a7fdcd6ce9 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_aten.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py @@ -1,6 +1,6 @@ import unittest import torch -from atat.pytorch.hook_module.wrap_aten import AtenOPTemplate, AtenOPPacketTemplate +from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate, AtenOPPacketTemplate def hook(name): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_distributed.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py similarity index 95% rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_distributed.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py index bd0501ef2fdfa90b00f9049c00aceb03e5225831..9a375e45bfcdc93ac36fb9d44a79f50fea7932d5 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py @@ -1,6 +1,6 @@ import unittest import torch.distributed as dist -from atat.pytorch.hook_module.wrap_distributed import * +from msprobe.pytorch.hook_module.wrap_distributed import * class TestWrapDistributed(unittest.TestCase): def hook(name, prefix): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_functional.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py similarity index 91% rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_functional.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py index 232117498b5900b1d7d178208b81430f86b52eb4..f43b8ea6cb98dd1947811b7b0641439b225b51ec 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_functional.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py @@ -1,6 +1,6 @@ import unittest import torch -from atat.pytorch.hook_module import wrap_functional as wf +from msprobe.pytorch.hook_module import wrap_functional as wf class TestWrapFunctional(unittest.TestCase): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_tensor.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py similarity index 88% rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_tensor.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py index e027270540e6a5ebfc1d3963ac29bae87c9b1110..61f76b0ca0a59ee680ff40991fa9cba5e42d869d 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_tensor.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py @@ -1,7 +1,7 @@ import unittest import torch import yaml -from atat.pytorch.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind +from msprobe.pytorch.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind class TestWrapTensor(unittest.TestCase): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_torch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py similarity index 96% rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_torch.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py index 8817bc758ae4b9f76f194e8c191119f56f483ea9..e1a3e77983d80e7c0519e30afbb592311550e794 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_torch.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py @@ -1,7 +1,7 @@ import unittest import torch import yaml -from atat.pytorch.hook_module.wrap_torch import * +from msprobe.pytorch.hook_module.wrap_torch import * class TestWrapTorch(unittest.TestCase): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_vf.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py similarity index 82% rename from debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_vf.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py index 8d57fad6eb623569acfbbffe28dc9decb1ca30b6..98efb4bc5b8a30284fe820124e48af7f487d1c54 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/hook_module/test_wrap_vf.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py @@ -1,6 +1,6 @@ import unittest import torch -from atat.pytorch.hook_module import wrap_vf +from msprobe.pytorch.hook_module import wrap_vf class TestWrapVF(unittest.TestCase): def setUp(self): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py new file mode 100644 index 0000000000000000000000000000000000000000..470390d77b2f233fb97f0916eed6d60c0b0c10ef --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py @@ -0,0 +1,84 @@ +from unittest import TestCase +from unittest.mock import patch, mock_open + +from msprobe.core.common.const import Const +from msprobe.pytorch.pt_config import parse_json_config, parse_task_config + + +class TestPtConfig(TestCase): + def test_parse_json_config(self): + mock_json_data = { + "task": "statistics", + "dump_path": "./dump/", + "rank": [], + "step": [], + "level": "L1", + "seed": 1234, + "statistics": { + "scope": [], + "list": [], + "data_mode": ["all"], + }, + "tensor": { + "file_format": "npy" + } + } + with patch("msprobe.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \ + patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ + patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data): + common_config, task_config = parse_json_config(None, None) + self.assertEqual(common_config.task, Const.STATISTICS) + self.assertEqual(task_config.data_mode, ["all"]) + + with patch("msprobe.pytorch.pt_config.os.path.join", return_value="/path/config.json"), \ + patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ + patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data): + common_config, task_config = parse_json_config(None, Const.TENSOR) + self.assertEqual(common_config.task, Const.STATISTICS) + self.assertEqual(task_config.file_format, "npy") + + def test_parse_task_config(self): + overflow_check_config = { + "overflow_check": { + "overflow_nums": 1, + "check_mode": "all" + } + } + result = parse_task_config(Const.OVERFLOW_CHECK, overflow_check_config) + self.assertEqual(result.overflow_nums, 1) + self.assertEqual(result.check_mode, "all") + + free_benchmark_config = { + "free_benchmark": { + "scope": [], + "list": ["conv2d"], + "fuzz_device": "npu", + "pert_mode": "improve_precision", + "handler_type": "check", + "fuzz_level": "L1", + "fuzz_stage": "forward", + "if_preheat": False, + "preheat_step": 15, + "max_sample": 20 + } + } + result = parse_task_config(Const.FREE_BENCHMARK, free_benchmark_config) + self.assertEqual(result.pert_mode, "improve_precision") + self.assertEqual(result.handler_type, "check") + self.assertEqual(result.preheat_step, 15) + self.assertEqual(result.max_sample, 20) + + run_ut_config = { + "run_ut": { + "white_list": ["conv2d"], + "black_list": ["matmul"], + "error_data_path": '/home/dump_path' + + } + } + with patch('os.path.exists', return_value=True) as mocked_exists: + result = parse_task_config(Const.RUN_UT, run_ut_config) + self.assertEqual(result.white_list, ["conv2d"]) + self.assertEqual(result.black_list, ["matmul"]) + self.assertEqual(result.error_data_path, '/home/dump_path') + mocked_exists.assert_called_once_with('/home/dump_path') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py new file mode 100644 index 0000000000000000000000000000000000000000..c09b6abcb693a048e360ed0d783a0c817c76b06f --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py @@ -0,0 +1,59 @@ +import unittest +from unittest.mock import patch, mock_open + +import torch.nn as nn +from msprobe.core.common.utils import Const +from msprobe.pytorch.debugger.debugger_config import DebuggerConfig +from msprobe.pytorch.pt_config import parse_json_config +from msprobe.pytorch.service import Service + + +class TestService(unittest.TestCase): + def setUp(self): + mock_json_data = { + "dump_path": "./dump/", + } + with patch("msprobe.pytorch.pt_config.FileOpen", mock_open(read_data='')), \ + patch("msprobe.pytorch.pt_config.json.load", return_value=mock_json_data): + common_config, task_config = parse_json_config("./config.json", Const.STATISTICS) + self.config = DebuggerConfig(common_config, task_config, Const.STATISTICS, "./ut_dump", "L1") + self.service = Service(self.config) + + def test_start(self): + with patch("msprobe.pytorch.service.get_rank_if_initialized", return_value=0), \ + patch("msprobe.pytorch.service.Service.create_dirs", return_value=None): + self.service.start(None) + self.assertEqual(self.service.current_rank, 0) + + def test_stop_and_step(self): + with patch("msprobe.core.data_dump.data_collector.DataCollector.write_json", return_value=None): + self.service.stop() + self.assertFalse(self.service.switch) + + self.service.step() + self.assertEqual(self.service.current_iter, 1) + + def test_register_hook_new(self): + class TestModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(in_features=8, out_features=4) + + def forward(self, x): + x = self.linear(x) + return x + + self.service.model = TestModule() + self.config.level = "L0" + with patch("msprobe.pytorch.service.logger.info_on_rank_0") as mock_logger, \ + patch("msprobe.pytorch.service.remove_dropout", return_value=None): + self.service.register_hook_new() + self.assertEqual(mock_logger.call_count, 2) + + def test_create_dirs(self): + with patch("msprobe.pytorch.service.Path.mkdir", return_value=None), \ + patch("msprobe.core.common.file_check.FileChecker.common_check", return_value=None), \ + patch("msprobe.core.data_dump.data_collector.DataCollector.update_dump_paths", + return_value=None): + self.service.create_dirs() + self.assertEqual(self.service.dump_iter_dir, "./ut_dump/step0") diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..9433bc136395773e118810b6154d906670b633f4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py @@ -0,0 +1,99 @@ +import unittest +from unittest.mock import MagicMock, patch +from msprobe.pytorch.visualization.builder.graph_builder import GraphBuilder, Graph, BaseNode, NodeOp + + +class TestGraphBuilder(unittest.TestCase): + + def setUp(self): + self.construct_path = "step/rank/construct.json" + self.data_path = "step/rank/dump.json" + self.model_name = "TestModel" + self.graph = Graph(self.model_name) + self.construct_dict = { + "Tensor1": "Module1", + "Module1": None + } + self.data_dict = { + "Module1": {"data": "data for Module1"}, + "Tensor1": {"data": "data for Tensor1"} + } + + @patch('msprobe.pytorch.visualization.builder.graph_builder.load_json_file') + @patch('msprobe.pytorch.visualization.builder.graph_builder.load_data_json_file') + def test_build(self, mock_load_data_json_file, mock_load_json_file): + mock_load_data_json_file.return_value = self.data_dict + mock_load_json_file.return_value = self.construct_dict + + graph = GraphBuilder.build(self.construct_path, self.data_path, self.model_name) + self.assertIsNotNone(graph) + self.assertIsInstance(graph, Graph) + self.assertEqual(len(graph.node_map), 3) + + @patch('msprobe.pytorch.visualization.builder.graph_builder.save_json_file') + def test_to_json(self, mock_save_json_file): + GraphBuilder.to_json("step/rank/output.vis", self.graph) + mock_save_json_file.assert_called_once() + + @patch('msprobe.pytorch.visualization.graph.node_op.NodeOp.get_node_op') + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.get_input_output', return_value=([], [])) + def test__init_nodes(self, mock_get_input_output, mock_get_node_op): + GraphBuilder._init_nodes(self.graph, self.construct_dict, self.data_dict) + mock_get_node_op.assert_any_call("Tensor1") + mock_get_node_op.assert_any_call("Module1") + self.assertIs(self.graph.root, self.graph.get_node("TestModel")) + + def test__create_or_get_node(self): + node_op = MagicMock() + data_dict = {"node1": {}} + node = GraphBuilder._create_or_get_node(self.graph, data_dict, node_op, "node1") + self.assertIn("node1", self.graph.node_map) + self.assertEqual(node.input_data, {}) + self.assertEqual(node.output_data, {}) + + def test__handle_backward_upnode_missing(self): + construct_dict = {'Module.module.a.forward.0': 'Module.root.forward.0', 'Module.module.a.backward.0': None, + 'Module.root.forward.0': None, 'Module.root.backward.0': None, + 'Module.module.b.forward.0': 'Module.root.forward.0', + 'Module.module.b.backward.0': 'Module.root.backward.0', 'Module.module.c.backward.0': None} + node_id_a = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.a.backward.0', None) + self.assertEqual(node_id_a, 'Module.root.backward.0') + node_id_b = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.b.backward.0', + 'Module.root.backward.0') + self.assertEqual(node_id_b, 'Module.root.backward.0') + node_id_c = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.c.backward.0', None) + self.assertIsNone(node_id_c) + + def test__collect_apis_between_modules_only_apis(self): + graph = Graph('TestNet') + graph.root.subnodes = [BaseNode(NodeOp.function_api, 'Tensor.a.0'), BaseNode(NodeOp.function_api, 'Tensor.b.0')] + GraphBuilder._collect_apis_between_modules(graph) + self.assertEqual(len(graph.root.subnodes), 1) + self.assertEqual(graph.root.subnodes[0].op, NodeOp.api_collection) + self.assertEqual(len(graph.root.subnodes[0].subnodes), 2) + self.assertEqual(graph.root.subnodes[0].id, 'Apis_Between_Modules.0') + + def test__collect_apis_between_modules_mixed_nodes(self): + graph = Graph('TestNet') + graph.root.subnodes = [BaseNode(NodeOp.function_api, 'Tensor.a.0'), BaseNode(NodeOp.module, 'Module.a.0'), + BaseNode(NodeOp.module, 'Module.b.0'), BaseNode(NodeOp.function_api, 'Tensor.b.0'), + BaseNode(NodeOp.function_api, 'Tensor.c.0'), BaseNode(NodeOp.module, 'Module.a.1')] + GraphBuilder._collect_apis_between_modules(graph) + self.assertEqual(len(graph.root.subnodes), 5) + self.assertEqual(graph.root.subnodes[0].op, NodeOp.function_api) + self.assertEqual(graph.root.subnodes[1].op, NodeOp.module) + self.assertEqual(graph.root.subnodes[3].op, NodeOp.api_collection) + self.assertEqual(len(graph.root.subnodes[3].subnodes), 2) + self.assertEqual(graph.root.subnodes[3].id, 'Apis_Between_Modules.0') + + def test__collect_apis_between_modules_only_modules(self): + graph = Graph('TestNet') + graph.root.subnodes = [BaseNode(NodeOp.module, 'Module.a.0'), BaseNode(NodeOp.module, 'Module.b.0'), + BaseNode(NodeOp.module, 'Module.a.1')] + GraphBuilder._collect_apis_between_modules(graph) + self.assertEqual(len(graph.root.subnodes), 3) + self.assertEqual(graph.root.subnodes[0].op, NodeOp.module) + self.assertEqual(graph.root.subnodes[1].op, NodeOp.module) + self.assertEqual(graph.root.subnodes[2].op, NodeOp.module) + self.assertEqual(len(graph.root.subnodes[0].subnodes), 0) + self.assertEqual(graph.root.subnodes[0].id, 'Module.a.0') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_msprobe_adapter.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_msprobe_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..f023128b8c7350b2bcefebe6007dc5ef46133e14 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_msprobe_adapter.py @@ -0,0 +1,104 @@ +import unittest +from unittest.mock import patch +from msprobe.pytorch.visualization.builder.msprobe_adapter import ( + get_compare_mode, + run_real_data, + get_input_output, + compare_data, + format_node_data, + compare_node, + _format_decimal_string, + _format_data, + compare_mapping_data +) +from msprobe.pytorch.visualization.utils import GraphConst + + +class TestMsprobeAdapter(unittest.TestCase): + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.task_dumppath_get', return_value=(True, False)) + def test_get_compare_mode_summary(self, mock_task_dumppath_get): + mode = get_compare_mode("dummy_param") + self.assertEqual(mode, GraphConst.SUMMARY_COMPARE) + + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter._do_multi_process') + def test_run_real_data(self, mock_do_multi_process): + run_real_data("dump_path", "csv_path") + mock_do_multi_process.assert_called_once_with("dump_path", "csv_path") + + def test_get_input_output(self): + node_data = { + 'input_args': [{'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [5], + 'Max': 2049.0, 'Min': 0.0, 'Mean': 410.20001220703125, 'Norm': 2049.0009765625, + 'requires_grad': False, 'full_op_name': 'Distributed.broadcast.0.forward_input.0'}, + {'type': 'int', 'value': 0}], + 'input_kwargs': {'group': None}, + 'output': [{'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [5], + 'Max': 2049.0, 'Min': 0.0, 'Mean': 410.20001220703125, 'Norm': 2049.0009765625, + 'requires_grad': False, 'full_op_name': 'Distributed.broadcast.0.forward_output.0'}, + {'type': 'int', 'value': 0}, None] + } + node_id = "Distributed.broadcast.0.forward" + input_data, output_data = get_input_output(node_data, node_id) + self.assertIn("Distributed.broadcast.0.forward_output.0", output_data) + self.assertIn("Distributed.broadcast.0.forward_input.0", input_data) + + def test_compare_data(self): + data_dict_list1 = {'key1': {'type': 'Type1', 'dtype': 'DType1', 'shape': 'Shape1'}} + data_dict_list2 = {'key1': {'type': 'Type1', 'dtype': 'DType1', 'shape': 'Shape1'}} + data_dict_list3 = {'key1': {'type': 'Type2', 'dtype': 'DType1', 'shape': 'Shape1'}} + data_dict_list4 = {} + self.assertTrue(compare_data(data_dict_list1, data_dict_list2)) + self.assertFalse(compare_data(data_dict_list1, data_dict_list3)) + self.assertFalse(compare_data(data_dict_list1, data_dict_list4)) + + def test_format_node_data(self): + data_dict = {'node1': {'data_name': 'data1', 'full_op_name': 'op1'}} + result = format_node_data(data_dict) + self.assertNotIn('data_name', result['node1']) + self.assertNotIn('requires_grad', result['node1']) + + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.get_accuracy') + def test_compare_node(self, mock_get_accuracy): + node_ids = ["node1", "node2"] + data_dicts = [{'node1': {"input_args": [], "input_kwargs": {}, "output": {}}}, + {'node2': {"input_args": [], "input_kwargs": {}, "output": {}}}] + stack_json_data = {} + result = compare_node(node_ids, data_dicts, stack_json_data, False, False) + mock_get_accuracy.assert_called_once() + self.assertIsInstance(result, list) + + def test__format_decimal_string(self): + s = "0.123456789%" + formatted_s = _format_decimal_string(s) + self.assertIn("0.123457%", formatted_s) + self.assertEqual('0.123457', _format_decimal_string('0.12345678')) + self.assertEqual('-1', _format_decimal_string('-1')) + self.assertEqual('0.0.25698548%', _format_decimal_string('0.0.25698548%')) + + def test__format_data(self): + data_dict = {'value': 0.123456789, 'value1': None, 'value2': "", 'value3': 1.123123123123e-11, + 'value4': torch.inf, 'value5': -1} + _format_data(data_dict) + self.assertEqual(data_dict['value'], '0.123457') + self.assertEqual(data_dict['value1'], 'null') + self.assertEqual(data_dict['value2'], '') + self.assertEqual(data_dict['value3'], '1.123123e-11') + self.assertEqual(data_dict['value4'], 'inf') + self.assertEqual(data_dict['value5'], '-1') + + all_none_dict = {'a': None, 'b': None, 'c': None, 'd': None, 'e': None} + _format_data(all_none_dict) + self.assertEqual({'value': 'null'}, all_none_dict) + + def test_compare_mapping_data(self): + dict1 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}, 'c': {'shape': [1, 2, 3]}} + dict2 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}, 'c': {'shape': [1, 2, 3]}} + dict3 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}} + dict4 = {'a': {'shape': [2, 1, 3]}, 'b': {'shape': [1, 2, 3]}} + dict5 = {'a': {'shape': [2, 2, 3]}, 'b': {'shape': [1, 2, 3]}} + dict6 = {'a': {'type': 'str'}} + self.assertTrue(compare_mapping_data(dict1, dict2)) + self.assertTrue(compare_mapping_data(dict1, dict3)) + self.assertTrue(compare_mapping_data(dict1, dict4)) + self.assertFalse(compare_mapping_data(dict1, dict5)) + self.assertTrue(compare_mapping_data(dict1, dict6)) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py new file mode 100644 index 0000000000000000000000000000000000000000..cb69ac7723d9bff42ef62aa93fd62131790b5fc2 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py @@ -0,0 +1,143 @@ +import unittest +from unittest.mock import patch +from unittest.mock import MagicMock +from msprobe.pytorch.visualization.compare.graph_comparator import GraphComparator +from msprobe.pytorch.visualization.graph.graph import Graph, BaseNode, NodeOp +from msprobe.pytorch.visualization.utils import GraphConst + + +class TestGraphComparator(unittest.TestCase): + + def setUp(self): + self.graphs = [Graph("model1"), Graph("model2")] + self.data_paths = ["step1/rank/dump.json", "step2/rank/dump.json"] + self.stack_path = "step1/rank/stack.json" + self.output_path = "output/output.vis" + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__parse_param(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + self.comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + self.comparator._parse_param(self.data_paths, self.stack_path, self.output_path) + + self.assertEqual(self.comparator.dump_path_param, { + 'npu_json_path': self.data_paths[0], + 'bench_json_path': self.data_paths[1], + 'stack_json_path': self.stack_path, + 'is_print_compare_log': True + }) + self.assertEqual(self.comparator.output_path, self.output_path) + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test_compare(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator._compare_nodes = MagicMock() + comparator._postcompare = MagicMock() + + comparator.compare() + + comparator._compare_nodes.assert_called_once() + comparator._postcompare.assert_called_once() + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test_add_compare_result_to_node(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + node = MagicMock() + compare_result_list = [("output1", "data1"), ("input1", "data2")] + + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator.ma = MagicMock() + comparator.ma.prepare_real_data.return_value = True + + comparator.add_compare_result_to_node(node, compare_result_list) + comparator.ma.prepare_real_data.assert_called_once_with(node) + node.data.update.assert_not_called() + + @patch('msprobe.pytorch.visualization.graph.node_colors.NodeColors.get_node_error_status') + @patch('msprobe.pytorch.visualization.utils.get_csv_df') + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.run_real_data') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__postcompare(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode, + mock_run_real_data, mock_get_csv_df, mock_get_node_error_status): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + mock_df = MagicMock() + mock_df.iterrows = MagicMock(return_value=[(None, MagicMock())]) + mock_run_real_data.return_value = mock_df + mock_get_csv_df.return_value = mock_df + mock_get_node_error_status.return_value = True + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator.ma = MagicMock() + comparator.ma.is_real_data_compare.return_value = True + comparator._handle_api_collection_index = MagicMock() + comparator.ma.compare_nodes = [MagicMock()] + comparator.ma.parse_result = MagicMock(return_value=(0.9, None)) + + comparator._postcompare() + + comparator._handle_api_collection_index.assert_called_once() + comparator.ma.add_error_key.assert_called() + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__handle_api_collection_index(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + apis = BaseNode(NodeOp.api_collection, 'Apis_Between_Modules.0') + api1 = BaseNode(NodeOp.function_api, 'Tensor.a.0') + api1.data = {GraphConst.JSON_INDEX_KEY: 0.9} + api2 = BaseNode(NodeOp.function_api, 'Tensor.b.0') + api2.data = {GraphConst.JSON_INDEX_KEY: 0.6} + apis.subnodes = [api1, api2] + sub_nodes = [BaseNode(NodeOp.module, 'Module.a.0'), apis, BaseNode(NodeOp.module, 'Module.a.1')] + comparator.graph_n.root.subnodes = sub_nodes + comparator._handle_api_collection_index() + self.assertEqual(comparator.graph_n.root.subnodes[1].data.get(GraphConst.JSON_INDEX_KEY), 0.6) + + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.compare_node') + @patch('msprobe.pytorch.visualization.graph.graph.Graph.match') + @patch('msprobe.pytorch.visualization.graph.graph.Graph.mapping_match') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__compare_nodes(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode, + mock_mapping_match, mock_match, mock_compare_node): + node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0') + node_b = BaseNode(NodeOp.function_api, 'Tensor.b.0') + mock_load_data_json_file.return_value = {} + mock_load_json_file.return_value = {} + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + mock_mapping_match.return_value = (node_b, [], []) + mock_compare_node.return_value = ['result'] + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator.mapping_config = True + comparator._compare_nodes(node_n) + self.assertEqual(node_n.matched_node_link, ['Tensor.b.0']) + self.assertEqual(node_b.matched_node_link, ['Tensor.a.0']) + comparator.mapping_config = False + node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0') + node_b = BaseNode(NodeOp.function_api, 'Tensor.a.0') + mock_match.return_value = (node_b, []) + comparator._compare_nodes(node_n) + self.assertEqual(node_n.matched_node_link, ['Tensor.a.0']) + self.assertEqual(node_b.matched_node_link, ['Tensor.a.0']) + diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..da76d8e0d57dc700d45a16f4ea79df1ff7ff1707 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py @@ -0,0 +1,99 @@ +import json +import unittest +from unittest.mock import patch, MagicMock +from msprobe.pytorch.visualization.compare.mode_adapter import ModeAdapter +from msprobe.pytorch.visualization.graph.base_node import BaseNode, NodeOp +from msprobe.pytorch.visualization.utils import GraphConst, ToolTip +from msprobe.core.common.const import CompareConst + + +class TestModeAdapter(unittest.TestCase): + + def setUp(self): + self.node_op = NodeOp.module + self.node_id = "node_1" + self.node = BaseNode(self.node_op, self.node_id) + self.compare_mode = GraphConst.REAL_DATA_COMPARE + self.adapter = ModeAdapter(self.compare_mode) + self.compare_data_dict = [{}, {}] + + def test_add_md5_compare_data(self): + node_data = {'md5_key': 'some_md5_value'} + compare_data_dict = {'md5_key': 'expected_md5_value'} + precision_index = ModeAdapter._add_md5_compare_data(node_data, compare_data_dict) + self.assertEqual(precision_index, 0) + + @patch('msprobe.pytorch.visualization.compare.mode_adapter.ModeAdapter') + def test_parse_result(self, mock_mode_adapter): + mock_mode_adapter._add_summary_compare_data.return_value = 0.5 + self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE + precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict) + self.assertEqual(precision_index, 0.5) + self.assertEqual(other_dict, {}) + + mock_mode_adapter._add_md5_compare_data.return_value = 1 + self.adapter.compare_mode = GraphConst.MD5_COMPARE + precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict) + self.assertEqual(precision_index, 1) + self.assertEqual(other_dict, {'Result': 'pass'}) + + mock_mode_adapter._add_real_compare_data.return_value = 0.6 + self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE + precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict) + self.assertEqual(precision_index, 0.0) + self.assertEqual(other_dict, {}) + + def test_prepare_real_data(self): + self.adapter.is_real_data_compare = MagicMock(return_value=True) + result = self.adapter.prepare_real_data(self.node) + self.assertTrue(result) + + self.adapter.is_real_data_compare = MagicMock(return_value=False) + result = self.adapter.prepare_real_data(self.node) + self.assertFalse(result) + + def test_compare_mode_methods(self): + self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE + self.assertTrue(self.adapter.is_summary_compare()) + self.assertFalse(self.adapter.is_md5_compare()) + self.assertFalse(self.adapter.is_real_data_compare()) + + def test_add_csv_data(self): + compare_result_list = ['result1', 'result2'] + self.adapter.add_csv_data(compare_result_list) + self.assertEqual(self.adapter.csv_data, compare_result_list) + + def test_add_error_key(self): + node_data = {'key': {}} + self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE + self.adapter.add_error_key(node_data) + self.assertEqual(node_data['key'][GraphConst.ERROR_KEY], + [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]) + node_data = {'key': {}} + self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE + self.adapter.add_error_key(node_data) + self.assertEqual(node_data['key'][GraphConst.ERROR_KEY], + [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, + CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]) + + def test_get_tool_tip(self): + self.adapter.compare_mode = GraphConst.MD5_COMPARE + tips = self.adapter.get_tool_tip() + self.assertEqual(tips, json.dumps({'md5': ToolTip.MD5})) + + self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE + tips = self.adapter.get_tool_tip() + self.assertEqual(tips, json.dumps({ + CompareConst.MAX_DIFF: ToolTip.MAX_DIFF, + CompareConst.MIN_DIFF: ToolTip.MIN_DIFF, + CompareConst.MEAN_DIFF: ToolTip.MEAN_DIFF, + CompareConst.NORM_DIFF: ToolTip.NORM_DIFF})) + + self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE + tips = self.adapter.get_tool_tip() + self.assertEqual(tips, json.dumps({ + CompareConst.ONE_THOUSANDTH_ERR_RATIO: ToolTip.ONE_THOUSANDTH_ERR_RATIO, + CompareConst.FIVE_THOUSANDTHS_ERR_RATIO: ToolTip.FIVE_THOUSANDTHS_ERR_RATIO, + CompareConst.COSINE: ToolTip.COSINE, + CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR, + CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR})) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0f12582a36b21000416f2603143b87d0032c65 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py @@ -0,0 +1,76 @@ +import unittest +from unittest.mock import patch +from msprobe.pytorch.visualization.graph.base_node import BaseNode, NodeOp +from msprobe.pytorch.visualization.utils import GraphConst + + +class TestBaseNode(unittest.TestCase): + + def setUp(self): + self.node_op = NodeOp.module + self.node_id = "node_1" + self.up_node = BaseNode(self.node_op, "up_node_1") + self.node = BaseNode(self.node_op, self.node_id, self.up_node) + + def test_init_and_str(self): + self.assertEqual(self.node.op, self.node_op) + self.assertEqual(self.node.id, self.node_id) + self.assertEqual(str(self.node), 'id:\tnode_1') + + def test_eq(self): + other_node = BaseNode(self.node_op, self.node_id, self.up_node) + self.assertEqual(self.node, other_node) + + def test_get_suggestions(self): + self.node.get_suggestions() + self.assertIn(GraphConst.SUGGEST_KEY, self.node.suggestions) + + node = BaseNode(NodeOp.function_api, "up_node_1") + node.get_suggestions() + self.assertIn(GraphConst.SUGGEST_KEY, node.suggestions) + + def test_set_input_output(self): + input_data = {'input1': 'value1'} + output_data = {'output1': 'value2'} + self.node.set_input_output(input_data, output_data) + self.assertEqual(self.node.input_data, input_data) + self.assertEqual(self.node.output_data, output_data) + + def test_add_upnode(self): + self.node = BaseNode(self.node_op, self.node_id) + new_up_node = BaseNode(self.node_op, "new_up_node_1") + self.node.add_upnode(new_up_node) + self.assertEqual(self.node.upnode, new_up_node) + self.assertIn(self.node, new_up_node.subnodes) + + def test_add_link(self): + other_node = BaseNode(self.node_op, "other_node_1") + ancestors = ['a1', 'a2'] + self.node.add_link(other_node, ancestors) + self.assertEqual(self.node.matched_node_link, ancestors) + self.assertEqual(other_node.matched_node_link, ancestors) + + def test_to_dict(self): + expected_result = { + 'id': self.node_id, + 'node_type': self.node_op.value, + 'data': {}, + 'output_data': {}, + 'input_data': {}, + 'upnode': self.up_node.id, + 'subnodes': [], + 'matched_node_link': [], + 'suggestions': {}, + 'stack_info': [] + } + self.assertEqual(self.node.to_dict(), expected_result) + + def test_get_ancestors(self): + expected_ancestors = ['up_node_1'] + self.assertEqual(self.node.get_ancestors(), expected_ancestors) + + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.compare_mapping_data') + def test_compare_mapping_node(self, mock_compare_mapping_data): + mock_compare_mapping_data.return_value = True + result = self.node.compare_mapping_node(BaseNode(NodeOp.function_api, "up_node_1")) + self.assertTrue(result) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a55bb36e2291caf14a5e27052c5600f4492d0c --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py @@ -0,0 +1,102 @@ +import unittest +from unittest.mock import MagicMock +from msprobe.pytorch.visualization.graph.graph import Graph, NodeOp +from msprobe.pytorch.visualization.graph.base_node import BaseNode +from msprobe.pytorch.visualization.utils import GraphConst + + +class TestGraph(unittest.TestCase): + + def setUp(self): + self.graph = Graph("model_name") + self.node_id = "node_id" + self.node_op = NodeOp.module + + def test_add_node_and_get_node(self): + self.graph.add_node(self.node_op, self.node_id) + node = self.graph.get_node(self.node_id) + self.assertIsNotNone(node) + self.assertIn(self.node_id, self.graph.node_map) + + node_id = "api" + graph = Graph("model_name") + for i in range(0, 9): + graph.add_node(NodeOp.function_api, node_id, id_accumulation=True) + self.assertEqual(len(graph.node_map), 10) + self.assertIn("api.0", graph.node_map) + self.assertIn("api.8", graph.node_map) + self.assertNotIn("api", graph.node_map) + + def test_to_dict(self): + self.graph.add_node(self.node_op, self.node_id) + result = self.graph.to_dict() + self.assertEqual(result[GraphConst.JSON_ROOT_KEY], "model_name") + self.assertIn(self.node_id, result[GraphConst.JSON_NODE_KEY]) + + def test_str(self): + self.graph.add_node(self.node_op, self.node_id) + expected_str = f'{self.node_id}' + self.assertIn(expected_str, str(self.graph)) + + def test_match(self): + graph_a = Graph("model_name_a") + graph_b = Graph("model_name_b") + node_a = BaseNode(self.node_op, self.node_id) + graph_a.add_node(NodeOp.module, "node_id_a") + graph_b.add_node(NodeOp.module, "node_id_b") + matched_node, ancestors = Graph.match(graph_a, node_a, graph_b) + self.assertIsNone(matched_node) + self.assertEqual(ancestors, []) + + graph_b.add_node(NodeOp.module, "node_id_a") + graph_a.add_node(NodeOp.module, "node_id_a_1", graph_a.get_node("node_id_a")) + graph_b.add_node(NodeOp.module, "node_id_a_1", graph_a.get_node("node_id_a")) + matched_node, ancestors = Graph.match(graph_a, graph_a.get_node("node_id_a_1"), graph_b) + self.assertIsNotNone(matched_node) + self.assertEqual(ancestors, ['node_id_a']) + + def test_dfs(self): + graph = Graph("model_name") + graph.add_node(NodeOp.module, "node_a") + graph.add_node(NodeOp.module, "node_b") + node_a = BaseNode(self.node_op, self.node_id) + result = {} + graph.dfs(node_a, result) + self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': {}, + 'output_data': {}, 'input_data': {}, 'upnode': 'None', 'subnodes': [], + 'matched_node_link': [], 'suggestions': {}}}) + + def test_split_nodes_by_micro_step(self): + nodes = [BaseNode(NodeOp.module, 'a.0'), BaseNode(NodeOp.module, 'b.0'), + BaseNode(NodeOp.api_collection, 'apis.0'), BaseNode(NodeOp.module, 'a.1'), + BaseNode(NodeOp.module, 'b.1'), BaseNode(NodeOp.api_collection, 'apis.1')] + result = Graph.split_nodes_by_micro_step(nodes) + self.assertEqual(len(result), 2) + self.assertEqual(len(result[0]), 3) + + def test_paging_by_micro_step(self): + nodes = [BaseNode(NodeOp.module, 'a.0'), BaseNode(NodeOp.module, 'b.0'), + BaseNode(NodeOp.api_collection, 'apis.0'), BaseNode(NodeOp.module, 'a.1'), + BaseNode(NodeOp.module, 'b.1'), BaseNode(NodeOp.api_collection, 'apis.1')] + + graph = Graph('Model1') + graph.root.subnodes = nodes + graph_other = Graph('Model2') + graph_other.root.subnodes = nodes + + result = graph.paging_by_micro_step(graph_other) + self.assertEqual(result, 2) + self.assertEqual(graph.root.subnodes[0].micro_step_id, 0) + self.assertEqual(graph_other.root.subnodes[0].micro_step_id, 0) + + def test_mapping_match(self): + mapping_config = MagicMock() + graph_a = Graph("model_name_a") + graph_b = Graph("model_name_b") + graph_a.add_node(NodeOp.module, "a1", BaseNode(NodeOp.module, "root")) + graph_b.add_node(NodeOp.module, "b1", BaseNode(NodeOp.module, "root")) + mapping_config.get_mapping_string.return_value = "b1" + node_b, ancestors_n, ancestors_b = Graph.mapping_match(graph_a.get_node("a1"), graph_b, mapping_config) + self.assertIsNotNone(node_b) + self.assertEqual(ancestors_n, ["root"]) + self.assertEqual(ancestors_b, ["root"]) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_colors.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_colors.py new file mode 100644 index 0000000000000000000000000000000000000000..0be05586b150cfd39670535cc3015925d9e2f44e --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_colors.py @@ -0,0 +1,67 @@ +import unittest +from msprobe.pytorch.visualization.graph.node_colors import NodeColors, SUMMARY_DESCRIPTION, REAL_DATA_DESCRIPTION, \ + NOT_MATCHED +from msprobe.pytorch.visualization.utils import GraphConst + + +class TestNodeColors(unittest.TestCase): + + def test_get_info_by_mode(self): + node_yellow = NodeColors.YELLOW_1 + summary_info = node_yellow.get_info_by_mode(GraphConst.SUMMARY_COMPARE) + self.assertEqual(summary_info[GraphConst.VALUE], [0, 0.2]) + self.assertEqual(summary_info[GraphConst.DESCRIPTION], SUMMARY_DESCRIPTION) + node_grey = NodeColors.GREY + md5_info = node_grey.get_info_by_mode(GraphConst.MD5_COMPARE) + self.assertEqual(md5_info[GraphConst.VALUE], []) + self.assertEqual(md5_info[GraphConst.DESCRIPTION], NOT_MATCHED) + node_red = NodeColors.RED + real_info = node_red.get_info_by_mode(GraphConst.REAL_DATA_COMPARE) + self.assertEqual(real_info[GraphConst.VALUE], [0.2, 1]) + self.assertEqual(real_info[GraphConst.DESCRIPTION], REAL_DATA_DESCRIPTION) + none_info = node_yellow.get_info_by_mode("non_existent_mode") + self.assertEqual(none_info, {}) + + def test_get_node_colors(self): + # 测试获取所有颜色信息的函数 + mode = GraphConst.SUMMARY_COMPARE + colors_info = NodeColors.get_node_colors(mode) + self.assertIn("#FFFCF3", colors_info) + self.assertIn("#FFEDBE", colors_info) + self.assertIn("#FFDC7F", colors_info) + self.assertIn("#FFC62E", colors_info) + self.assertIn("#E32020", colors_info) + self.assertIn("#C7C7C7", colors_info) + + # 确保返回的字典具有正确的描述和值范围 + expected_value_range = [0, 0.2] + expected_description = "此节点所有输入输出的统计量相对误差, 值越大代表测量值与标杆值的偏差越大, 相对误差计算方式:|(测量值-标杆值)/标杆值|" + self.assertEqual(colors_info["#FFFCF3"][GraphConst.VALUE], expected_value_range) + self.assertEqual(colors_info["#FFFCF3"][GraphConst.DESCRIPTION], expected_description) + + mode = GraphConst.MD5_COMPARE + colors_info = NodeColors.get_node_colors(mode) + self.assertIn("#FFFCF3", colors_info) + self.assertIn("#C7C7C7", colors_info) + self.assertNotIn("#FFDC7F", colors_info) + + expected_value_range = [1, 1] + expected_description = "与标杆相比, 此节点所有输入输出的md5值相同" + self.assertEqual(colors_info["#FFFCF3"][GraphConst.VALUE], expected_value_range) + self.assertEqual(colors_info["#FFFCF3"][GraphConst.DESCRIPTION], expected_description) + + def test_get_node_error_status(self): + # 测试错误状态判断功能 + mode = GraphConst.SUMMARY_COMPARE + value0 = 0 + value1 = 0.25 + value2 = 0.55 + value3 = 111 + self.assertFalse(NodeColors.get_node_error_status(mode, value0)) + self.assertFalse(NodeColors.get_node_error_status(mode, value1)) + self.assertTrue(NodeColors.get_node_error_status(mode, value2)) + self.assertTrue(NodeColors.get_node_error_status(mode, value3)) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_op.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1a340ac8b3c7144a9e07485c93e289a950eee8c7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_op.py @@ -0,0 +1,28 @@ +import unittest +from msprobe.pytorch.visualization.graph.node_op import NodeOp + + +class TestNodeOp(unittest.TestCase): + + def test_get_node_op_valid(self): + node_name = "ModuleTest" + self.assertEqual(NodeOp.get_node_op(node_name), NodeOp.module) + + def test_get_node_op_invalid(self): + node_name = "InvalidNodeName" + with self.assertRaises(Exception): + NodeOp.get_node_op(node_name) + + def test_get_node_op_all(self): + test_cases = [ + ("ModuleTest", NodeOp.module), + ("TensorTest", NodeOp.function_api), + ("TorchTest", NodeOp.function_api), + ("FunctionalTest", NodeOp.function_api), + ("NPUTest", NodeOp.function_api), + ("VFTest", NodeOp.function_api), + ("DistributedTest", NodeOp.function_api), + ("AtenTest", NodeOp.function_api) + ] + for node_name, expected_op in test_cases: + self.assertEqual(NodeOp.get_node_op(node_name), expected_op) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8b2f85ebf872aae4b3377842ac899824da5877f9 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml @@ -0,0 +1,2 @@ +- vision_model: "language_model.vision_encoder" +- vision_projection: "language_model.projection" \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py new file mode 100644 index 0000000000000000000000000000000000000000..010a4f686198ef1127299fcad1b9a5abf6505d6b --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py @@ -0,0 +1,52 @@ +import os +import unittest +from msprobe.pytorch.visualization.mapping_config import MappingConfig + + +class TestMappingConfig(unittest.TestCase): + + def setUp(self): + self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml") + + def test_validate(self): + with self.assertRaises(ValueError): + MappingConfig.validate(123, "some value") + with self.assertRaises(ValueError): + MappingConfig.validate("some key", 456) + self.assertEqual(MappingConfig.validate("key", "value"), "value") + + def test_convert_to_regex(self): + regex = MappingConfig.convert_to_regex("hello{world}") + self.assertEqual(regex, ".*hello\\{world\\}.*") + + def test_replace_parts(self): + result = MappingConfig._replace_parts('hello world', 'world', 'everyone') + self.assertEqual(result, 'hello everyone') + result = MappingConfig._replace_parts('radio_model.layers.0.input_norm', 'radio_model.layers.{}.input_norm', + 'radio_model.transformer.layers.{}.input_layernorm') + self.assertEqual(result, 'radio_model.transformer.layers.0.input_layernorm') + + def test_get_mapping_string(self): + mc = MappingConfig(self.yaml_path) + mc.classify_config = { + 'category1': [('category1.key1', 'replacement1')], + 'category2': [('category2.key1', 'replacement2')] + } + result = mc.get_mapping_string("some category1.key1 text") + self.assertEqual(result, "some replacement1 text") + + def test_long_string(self): + long_string = "x" * (MappingConfig.MAX_STRING_LEN + 1) + mc = MappingConfig(self.yaml_path) + result = mc.get_mapping_string(long_string) + self.assertEqual(result, long_string) + + def test__classify_and_sort_keys(self): + mc = MappingConfig(self.yaml_path) + result = mc._classify_and_sort_keys() + self.assertEqual(result, {'vision_model': [('vision_model', 'language_model.vision_encoder')], + 'vision_projection': [('vision_projection', 'language_model.projection')]}) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56493235cd63d97b8a2beca0358c59ed16a154cf --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py @@ -0,0 +1,27 @@ +import os +import unittest +from msprobe.pytorch.visualization.utils import (load_json_file, load_data_json_file, str2float) + + +class TestMappingConfig(unittest.TestCase): + + def setUp(self): + self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml") + + def test_load_json_file(self): + result = load_json_file(self.yaml_path) + self.assertEqual(result, {}) + + def test_load_data_json_file(self): + result = load_data_json_file(self.yaml_path) + self.assertEqual(result, {}) + + def test_str2float(self): + result = str2float('23.4%') + self.assertAlmostEqual(result, 0.234) + result = str2float('2.3.4%') + self.assertAlmostEqual(result, 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/atat/test/resources/advisor.txt b/debug/accuracy_tools/msprobe/test/resources/advisor.txt similarity index 100% rename from debug/accuracy_tools/atat/test/resources/advisor.txt rename to debug/accuracy_tools/msprobe/test/resources/advisor.txt diff --git a/debug/accuracy_tools/atat/test/resources/compare_result_20230703104808.csv b/debug/accuracy_tools/msprobe/test/resources/compare_result_20230703104808.csv similarity index 100% rename from debug/accuracy_tools/atat/test/resources/compare_result_20230703104808.csv rename to debug/accuracy_tools/msprobe/test/resources/compare_result_20230703104808.csv diff --git a/debug/accuracy_tools/atat/test/resources/compare_result_without_accuracy.csv b/debug/accuracy_tools/msprobe/test/resources/compare_result_without_accuracy.csv similarity index 100% rename from debug/accuracy_tools/atat/test/resources/compare_result_without_accuracy.csv rename to debug/accuracy_tools/msprobe/test/resources/compare_result_without_accuracy.csv diff --git a/debug/accuracy_tools/atat/test/resources/config.yaml b/debug/accuracy_tools/msprobe/test/resources/config.yaml similarity index 100% rename from debug/accuracy_tools/atat/test/resources/config.yaml rename to debug/accuracy_tools/msprobe/test/resources/config.yaml diff --git a/debug/accuracy_tools/atat/test/resources/npu_test.pkl b/debug/accuracy_tools/msprobe/test/resources/npu_test.pkl similarity index 100% rename from debug/accuracy_tools/atat/test/resources/npu_test.pkl rename to debug/accuracy_tools/msprobe/test/resources/npu_test.pkl diff --git a/debug/accuracy_tools/atat/test/run_test.sh b/debug/accuracy_tools/msprobe/test/run_test.sh similarity index 100% rename from debug/accuracy_tools/atat/test/run_test.sh rename to debug/accuracy_tools/msprobe/test/run_test.sh diff --git a/debug/accuracy_tools/atat/test/run_ut.py b/debug/accuracy_tools/msprobe/test/run_ut.py similarity index 97% rename from debug/accuracy_tools/atat/test/run_ut.py rename to debug/accuracy_tools/msprobe/test/run_ut.py index 7c593c14abca82f39050276255316693a47c6fc9..8ea81ccca719952bdb8a6603b998902df94a53fb 100644 --- a/debug/accuracy_tools/atat/test/run_ut.py +++ b/debug/accuracy_tools/msprobe/test/run_ut.py @@ -3,7 +3,7 @@ import shutil import subprocess import sys -from atat.core.common.log import logger +from msprobe.core.common.log import logger def run_ut(): diff --git a/debug/accuracy_tools/atat/test/test_module_processer.py b/debug/accuracy_tools/msprobe/test/test_module_processer.py similarity index 95% rename from debug/accuracy_tools/atat/test/test_module_processer.py rename to debug/accuracy_tools/msprobe/test/test_module_processer.py index 89ee299f66fc1c14e9fcb666bd6d21e0ca2f17a9..448c35f0554551884dc690a71aef6bc8141e9a39 100644 --- a/debug/accuracy_tools/atat/test/test_module_processer.py +++ b/debug/accuracy_tools/msprobe/test/test_module_processer.py @@ -1,6 +1,6 @@ import unittest -from atat.pytorch.module_processer import ModuleProcesser -from atat.pytorch.common.utils import Const +from msprobe.pytorch.module_processer import ModuleProcesser +from msprobe.pytorch.common.utils import Const import torch diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py index 3568e3a47c2fa2a42b016fb787683565e4c3ab0b..58b6881b41aa44098996a24877aeae43a66dcced 100644 --- a/debug/accuracy_tools/setup.py +++ b/debug/accuracy_tools/setup.py @@ -14,27 +14,54 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +import subprocess +import setuptools +from setuptools import setup -from setuptools import setup, find_packages + +repository_url = "https://gitee.com/ascend/msit.git" +target_dir = "./msprobe/msit" +branch_name = "master" + +subprocess.check_call(["git", "submodule", "add", "-b", branch_name, repository_url, target_dir]) +subprocess.check_call(["git", "submodule", "init"]) +subprocess.check_call(["git", "submodule", "update"]) + + +EXCLUDE_PKGS = [ + "api_accuracy_checker*", + "grad_tool*", + "monitor*", + "ptdbg_ascend*", + "msprobe.test*", +] setup( - name='ascend_training_accuracy_tools', - version='1.0', + name='mindstudio-probe', + version='1.0.4', description='This is a pytorch precision comparison tools', long_description='This is a pytorch precision comparison tools, include ptdbg and api accuracy checker', - packages=find_packages(), + packages=setuptools.find_namespace_packages(exclude=EXCLUDE_PKGS, include=["msprobe", "msprobe*"]), install_requires=[ "wheel", - "numpy", - "pandas >= 1.3.5", + "einops", + "numpy < 2.0", + "pandas >= 1.3.5, < 2.1", "pyyaml", "rich", "tqdm", - "openpyxl" + "openpyxl", + "pyOpenSSL", + "twisted", + "matplotlib", + "protobuf <= 3.20.1", + "service_identity", + "onnx" ], include_package_data=True, ext_modules=[], zip_safe=False, entry_points={ - 'console_scripts' : ['atat=atat.atat:main'], - },) \ No newline at end of file + 'console_scripts': ['msprobe=msprobe.msprobe:main', + 'add_torchair_compare_path=msprobe.pytorch.torchair_compare.msit_path_add:create_symlink'], + },) diff --git a/plugins/mindstudio-insight-plugins/CMakeLists.txt b/plugins/mindstudio-insight-plugins/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..48c8919553685570400ef939df31cfcf9eb7d120 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/CMakeLists.txt @@ -0,0 +1,13 @@ +cmake_minimum_required(VERSION 3.20) +project(MindStudio_Board) +set(PROJECT_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +include(${CMAKE_CURRENT_SOURCE_DIR}/plugin_core/cmake/mind_expression.cmake) +add_subdirectory(proto) + +#if (${CMAKE_BUILD_TYPE} MATCHES "Debug") +message(STATUS "Open tools generate") +add_subdirectory(tools/httpServer) +#endif () +# add_subdirectory(Scalar) +add_subdirectory(Histogram) +add_subdirectory(plugin_core) \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Histogram/CMakeLists.txt b/plugins/mindstudio-insight-plugins/Histogram/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b1df7818cbf8dc8e82d6f6d9326782072f2cd1b --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/CMakeLists.txt @@ -0,0 +1,2 @@ +set(HISTOGRAM_PROJECT_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +add_subdirectory(server) \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/CMakeLists.txt b/plugins/mindstudio-insight-plugins/Histogram/server/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1d10f8d8a0674a039a224a5e546ebd8d842aa97a --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.20) +project(Histogram) + +set(CMAKE_BUILD_TYPE Debug CACHE STRING "Build type" FORCE) + +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0) + message(FATAL_ERROR "GCC version must be 7.3.0 and above, but found ${CMAKE_CXX_COMPILER_VERSION}") + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 11.4.0) + message(WARNING "GCC version ${CMAKE_CXX_COMPILER_VERSION} is greater than 11.4.0, may cause unknown problems.") + endif() +endif() + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_C_STANDARD 11) + +set(HOME_DIR ${PROJECT_SOURCE_DIR}) +set(EXECUTABLE_OUTPUT_PATH ${PROJECT_ROOT_DIR}/output/plugins) +set(LIBRARY_OUTPUT_PATH ${PROJECT_ROOT_DIR}/output/plugins) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector-all -D_FORTIFY_SOURCE=2 -O0 -g -ftrapv -fstack-protector-strong -fPIE -fPIC") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstack-protector-all -D_FORTIFY_SOURCE=2 -O0 -g -ftrapv -fstack-protector-strong -fPIE -fPIC") +if (${CMAKE_BUILD_TYPE} MATCHES "Debug") + message(STATUS "Enable debug symbol table, change optimization level to 0") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -g -O0") +endif () + +set(CMAKE_SKIP_RPATH true) +if (CMAKE_SYSTEM_NAME MATCHES "windows") + if ((NOT CMAKE_BUILD_TYPE MATCHES "Debug") AND (NOT CMAKE_BUILD_TYPE MATCHES "PROFILE")) + message(STATUS "Build type = ${CMAKE_BUILD_TYPE}, static = enable.") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--nxcompat -Wl,--dynamicbase -s -pie -Wincompatible-pointer-types") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wl,--nxcompat -Wl,--dynamicbase -s -pie -Wincompatible-pointer-types") + add_link_options(-static) + endif() +elseif() + if ((NOT CMAKE_BUILD_TYPE MATCHES "Debug") AND (NOT CMAKE_BUILD_TYPE MATCHES "PROFILE")) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s -pie -Wl,-z,now") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -s -pie -Wl,-z,now") + endif() +endif() + + +add_subdirectory(src) \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/CMakeLists.txt b/plugins/mindstudio-insight-plugins/Histogram/server/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b054d2e9d6eed28708488a957edb05a5e168e0a3 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/CMakeLists.txt @@ -0,0 +1,84 @@ +set(SRC_HOME_DIR ${HOME_DIR}/src) +aux_source_directory(defs HISTOGRAM_SRC_LIST) +aux_source_directory(handler HISTOGRAM_SRC_LIST) +aux_source_directory(histogramManager HISTOGRAM_SRC_LIST) +aux_source_directory(histoParser HISTOGRAM_SRC_LIST) +aux_source_directory(viewManager HISTOGRAM_SRC_LIST) +aux_source_directory(plugin HISTOGRAM_SRC_LIST) +aux_source_directory(${PROJECT_ROOT_DIR}/proto HISTOGRAM_SRC_LIST) + +set(LOG_SRC ${PROJECT_ROOT_DIR}/plugin_core/src/Logger.cpp) + +list(APPEND ${PROJECT_NAME}_SRC + ${PROTO_SRC} + ${HISTOGRAM_SRC_LIST} + ${LOG_SRC}) +include_directories(${SRC_HOME_DIR} + ${SRC_HOME_DIR}/parser + ${PROJECT_ROOT_DIR}/proto + ${PROJECT_ROOT_DIR} + ${SRC_HOME_DIR}/defs + ${SRC_HOME_DIR}/handler + ${SRC_HOME_DIR}/histoParser + ${SRC_HOME_DIR}/viewManager + ${SRC_HOME_DIR}/plugin + +) +set(LIBRARY_OUTPUT_PATH ${LIBRARY_OUTPUT_PATH}/${PROJECT_NAME}) +add_library(${PROJECT_NAME} SHARED ${${PROJECT_NAME}_SRC} + utils/FileUtils.h + handler/GetHistoDataRequestHandler.h + handler/GetHistoDataRequestHandler.cpp + handler/GetNewFilesRequestHandler.h + handler/GetNewFilesRequestHandler.cpp + handler/AddImportFileRequestHandler.cpp + handler/AddImportFileRequestHandler.h + handler/ImportRequestHandler.h + viewManager/FileMonitor.h + histoParser/MindsporeParser.h + histoParser/MindsporeParser.cpp) +target_include_directories(${PROJECT_NAME} PRIVATE ${${PROJECT_NAME}_H}) +target_include_directories(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/rapidjson/include/rapidjson) +if (${CMAKE_BUILD_TYPE} MATCHES "Debug") + message(STATUS "Open Debug Info") + target_compile_options(${PROJECT_NAME} PRIVATE -O0 -g) +endif () +target_include_directories(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/plugin_core/include) +target_link_libraries(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/output/lib/libmsinsight.so) +target_link_libraries(${PROJECT_NAME} PRIVATE mindboard::protobuf) +set(CMAKE_CXX_VISIBILITY_PRESET default) +if (${CMAKE_SYSTEM_NAME} MATCHES "Linux") + target_link_libraries(${PROJECT_NAME} PRIVATE stdc++fs) +endif () + +#-------- test -------- + +if (ENABLE_TESTCASES OR ENABLE_CPP_ST) + enable_testing() + aux_source_directory(test TEST_SRC) + aux_source_directory(${PROJECT_ROOT_DIR}/plugin_core/src TEST_SRC) + list(APPEND ${PROJECT_NAME}_TEST_SRC + ${${PROJECT_NAME}_SRC} + ${TEST_SRC}) + list(APPEND ${PROJECT_NAME}_TEST_H + ${${PROJECT_NAME}_H}) + add_executable(${PROJECT_NAME}_test ${${PROJECT_NAME}_TEST_SRC}) + message(STATUS "Dir:${PROJECT_ROOT_DIR}") + target_include_directories(${PROJECT_NAME}_test PUBLIC ${${PROJECT_NAME}_TEST_H}) + target_include_directories(${PROJECT_NAME}_test PUBLIC ${PROJECT_ROOT_DIR}/plugin_core/include) + target_include_directories(${PROJECT_NAME}_test PUBLIC ${PROJECT_ROOT_DIR}) + target_include_directories(${PROJECT_NAME}_test PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/plugin) + # if (CMAKE_SYSTEM_NAME STREQUAL "Linux") + message(STATUS "Open Debug ${CMAKE_BUILD_TYPE}") + target_link_libraries(${PROJECT_NAME}_test PRIVATE stdc++fs) + # endif () + target_link_libraries(${PROJECT_NAME}_test PRIVATE mindboard::protobuf) + target_link_libraries(${PROJECT_NAME}_test PUBLIC gtest_main gmock_main) + target_link_libraries(${PROJECT_NAME}_test PUBLIC ${PROJECT_ROOT_DIR}/output/lib/libmsinsight.so) + target_link_libraries(${PROJECT_NAME}_test PRIVATE mindboard::protobuf) + if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + target_compile_options(${PROJECT_NAME}_test PRIVATE -O0 -g) + endif () + + set_target_properties(${PROJECT_NAME}_test PROPERTIES EXCLUDE_FROM_ALL true) +endif () \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/HistoConceptDefs.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/HistoConceptDefs.h new file mode 100644 index 0000000000000000000000000000000000000000..223d65333b0758357e56ccc563a2ea4ec0ce2a72 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/HistoConceptDefs.h @@ -0,0 +1,140 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef CONCEPTDEFS_H +#define CONCEPTDEFS_H +#include +#include + +#include "rapidjson.h" +#include "document.h" +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif +#include + +#include "proto/event.pb.h" +#include "proto/mindspore_summary.pb.h" + +namespace Insight::Histogram { + using value_t = rapidjson::Value; + using document_t = rapidjson::Document; + + enum class ParseDataType { + MindSpore_Summary = 0, + TF_EVENT = 1, + Unknown = 2 + }; + + enum ErrCode : int { + OK = 0, + INVALID_REQUEST_JSON, + REQUEST_INVALID_PARAM, + INVALID_PATH + }; + + // 直方图上每一条线,y轴的高度和属于哪一个step,由HistogramProto转化而来 + struct HistogramLine { + int64_t step_{0}; + std::vector bucket_; + std::vector bucketLimit_; + HistogramLine() = default; + HistogramLine(int64_t step, tensorboard::HistogramProto hisProto) { + step_ = step; + std::copy(hisProto.bucket().begin(), hisProto.bucket().end(), std::back_inserter(bucket_)); + std::copy(hisProto.bucket_limit().begin(), hisProto.bucket_limit().end(), + std::back_inserter(bucketLimit_)); + } + + HistogramLine(int64_t step, mindspore::irpb::Summary_Histogram histogram) { + step_ = step; + for (auto& bucket : histogram.buckets()) { + bucketLimit_.push_back(bucket.left() + bucket.width()); + bucket_.push_back(bucket.count()); + } + } + + value_t CreatHistoLineValue(document_t::AllocatorType& allocator) { + value_t histoGramLine(rapidjson::kObjectType); + // bucket_ 数据 + value_t bucket(rapidjson::kArrayType); + for (const auto& count: bucket_) { + bucket.PushBack(count, allocator); + } + value_t bucketLimit(rapidjson::kArrayType); + for (const auto& limit: bucketLimit_) { + bucketLimit.PushBack(limit, allocator); + } + histoGramLine.AddMember("step", step_, allocator); + histoGramLine.AddMember("bucket", bucket, allocator); + histoGramLine.AddMember("bucketLimit", bucketLimit, allocator); + return histoGramLine; + } + }; + + struct HistogramGraph { + std::vector hsitogramLines_; + // 记录上一次getdata请求的时候hsitogramLines_长度,初始未请求时是0 + uint64_t length{0}; + // 下采样的缓存数据 + std::vector downsampleCatchLines; + HistogramGraph() = default; + explicit HistogramGraph(const std::vector& hsitogramLines) : hsitogramLines_(hsitogramLines) {} + void AddValue(const HistogramLine& line) { + hsitogramLines_.push_back(line); + } + void MergeData(const HistogramGraph& graph) { + hsitogramLines_.insert(hsitogramLines_.begin(),graph.hsitogramLines_.begin(), graph.hsitogramLines_.end()); + } + + // 对数据进行下采样 + void SetdDownsampleCatchLines() { + constexpr int reservoirNum = 50; + // hsitogramLines_ 小于 length 或者抽样数量比存储的数量多 + if (hsitogramLines_.size() <= length || reservoirNum > hsitogramLines_.size()) { + return; + } + + std::vector indices(hsitogramLines_.size() - 1); + std::iota(indices.begin(), indices.end(), 0); + + // 使用固定的随机数种子0进行随机打乱 + std::default_random_engine generator(0); + std::shuffle(indices.begin(), indices.end(), generator); + + // 选择前k - 1 个元素,并添加最后一个元素 + indices.resize(reservoirNum - 1); + std::sort(indices.begin(), indices.end()); + indices.push_back(hsitogramLines_.size() - 1); + + std::vector result; + for (int idx : indices) { + downsampleCatchLines.push_back(hsitogramLines_[idx]); + } + } + + void UpdateCatchData() { + // 更新缓存的下采样数据 + SetdDownsampleCatchLines(); + // 更新长度到最新的长度 + length = hsitogramLines_.size(); + } + + value_t CreatHistogramValue(document_t::AllocatorType& allocator) { + // 获取json说明接收到了getdata的请求,所以更新下采样数据 + UpdateCatchData(); + // 转成json格式 + value_t hsitogramLinesValue(rapidjson::kArrayType); + for (auto& hsitogramLine: downsampleCatchLines) { + value_t histogram = hsitogramLine.CreatHistoLineValue(allocator); + hsitogramLinesValue.PushBack(histogram, allocator); + } + return hsitogramLinesValue; + } + }; +} +#endif //CONCEPTDEFS_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/RequestDef.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/RequestDef.h new file mode 100644 index 0000000000000000000000000000000000000000..55a63a98021ddce904a247b27c4d9cabe83f62b8 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/RequestDef.h @@ -0,0 +1,99 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef REQUEST_DEF_H +#define REQUEST_DEF_H + +#include +#include +#include "rapidjson.h" +#include "document.h" + +#include "HistoConceptDefs.h" + +namespace Insight::Histogram { + using document_t = rapidjson::Document; + using value_t = rapidjson::Value; + using size_type_t =rapidjson::SizeType; + struct ImportRequest { + std::vector rootPathList; + ImportRequest() = default; + explicit ImportRequest(std::string_view request) { + std::string jsonString(request); + document_t document; + if (document.Parse(jsonString.c_str()).HasParseError()) { + return; + } + if (document.HasMember("rootPaths") && document["rootPaths"].IsArray()) { + const value_t& pathList = document["rootPaths"]; + for (size_type_t i = 0; i < pathList.Size(); i++) { + rootPathList.emplace_back(pathList[i].GetString()); + } + } + } + }; + + // 获取新增文件的接口,返回rootpath下面路径列表,没有入参 filrlist + struct GetNewFilePathRequest { + }; + + // 更新file和tag的映射关系, 返回的是file tagsd对应关系,这部分一定是没有解析过的文件最新导入进来 + struct AddImportFileRequest { + std::vector filepathList; //确定的具体文件 + AddImportFileRequest() = default; + explicit AddImportFileRequest(std::string_view request) { + std::string jsonString(request); + document_t document; + if (document.Parse(jsonString.c_str()).HasParseError()) { + return; + } + if (document.HasMember("filePathList") && document["filePathList"].IsArray()) { + const value_t& pathList = document["filePathList"]; + for (size_type_t i = 0; i < pathList.Size(); i++) { + filepathList.emplace_back(pathList[i].GetString()); + } + } + } + }; + + /* + * 前端发送的json请求格式如下 + * { + * "filePathTotags": [ + * {"filepath": "/home/text/path1", "tag": "tagXXX1"}, + * {"filepath": "/home/text/path2", "tag": "tagXXX2"}, + * ] + * } + **/ + // 获取具体某个图的时候,会限定是哪个文件的哪个tag的图。一个文件可能有多个 + struct GetHistoDataRequest { + std::map> filePathTotags; + GetHistoDataRequest() = default; + explicit GetHistoDataRequest(std::string_view request) + { + std::string jsonString(request); + document_t document; + if (document.Parse(jsonString.c_str()).HasParseError()) { + return; + } + if (document.HasMember("filePathToTags") && document["filePathToTags"].IsArray()) { + const value_t& pathList = document["filePathToTags"]; + for (size_type_t i = 0; i < pathList.Size(); i++) { + const value_t& pathToTag = pathList[i]; + if (!pathToTag.HasMember("filePath") || !pathToTag.HasMember("tag")) { + continue; + } + if (!pathToTag["filePath"].IsString() || !pathToTag["tag"].IsString()) { + continue; + } + std::string path = pathToTag["filePath"].GetString(); + if (filePathTotags.find(path) == filePathTotags.end()) { + filePathTotags[path] = std::set(); + } + filePathTotags[path].insert(pathToTag["tag"].GetString()); + } + } + } + }; +} +#endif //REQUEST_DEF_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/ResponseDef.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/ResponseDef.h new file mode 100644 index 0000000000000000000000000000000000000000..26cfca53d88d0310402a54ffd6d158d899d6d632 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/defs/ResponseDef.h @@ -0,0 +1,141 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ + +#ifndef RESPONSE_DEF_H +#define RESPONSE_DEF_H + +#include +#include +#include +#include +#include "rapidjson.h" +#include "document.h" + +#include "defs/HistoConceptDefs.h" + +namespace Insight::Histogram { + +using allocator_t = rapidjson::Document::AllocatorType; +using document_t = rapidjson::Document; +using value_t = rapidjson::Value; +using size_type_t =rapidjson::SizeType; + +struct ResponseDef { + std::string message_; + bool result_{true}; + int errCode_{0}; + ResponseDef() = default; + ResponseDef(const std::string & message, bool result, int errCode) : + message_(message), result_(result), errCode_(errCode) {} + + virtual value_t CreatDataValue(document_t::AllocatorType& allocator) { + value_t body(rapidjson::kObjectType); + return body; + } + + document_t CreatBasicJson() { + document_t document; + document.SetObject(); + document_t::AllocatorType& allocator = document.GetAllocator(); + document.AddMember("errCode", errCode_, allocator); + document.AddMember("msg", value_t().SetString(message_.c_str(), allocator), allocator); + document.AddMember("result", result_, allocator); + + value_t body = CreatDataValue(allocator); + document.AddMember("body", body, allocator); + return document; + } + + std::string ToJsonString() + { + document_t document = CreatBasicJson(); + // json 转换成string格式 + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + document.Accept(writer); + + std::string jsonString = buffer.GetString(); + return jsonString; + } + virtual ~ResponseDef() = default; +}; + +struct ImportResponse : public ResponseDef { + ImportResponse() = default; + explicit ImportResponse(const std::map> &filePathToTags) : filePathToTags_(filePathToTags) {} + ImportResponse(const std::string & message, bool result, int errCode, const std::map> &filePathToTags) : + ResponseDef(message, result, errCode), filePathToTags_(filePathToTags) {} + std::map> filePathToTags_; + // 数据转换成json格式,返回给前端 + value_t CreatDataValue(document_t::AllocatorType& allocator) override + { + value_t data(rapidjson::kArrayType); + for (const auto& [filePath, tags] : filePathToTags_) { + // filepath是一个json对象 + value_t entry(rapidjson::kObjectType); + entry.AddMember("filePath", value_t().SetString(filePath.c_str(), allocator), allocator); + // tags是一个集合,新建一个数组对象 + value_t tagList(rapidjson::kArrayType); + for (const auto& tag : tags) { + tagList.PushBack(value_t().SetString(tag.c_str(), allocator), allocator); + } + entry.AddMember("tagList", tagList, allocator); + data.PushBack(entry, allocator); + } + value_t body(rapidjson::kObjectType); + body.AddMember("data", data, allocator); + return body; + } +}; + +// 获取采集路径下的新增文件,request的时候不需要参数 +struct GetNewFilesResponse : public ResponseDef { + std::vector filePaths_; + GetNewFilesResponse() = default; + explicit GetNewFilesResponse(const std::vector &filePaths) : filePaths_(filePaths) {} + GetNewFilesResponse(const std::string & message, bool result, int errCode, + const std::vector &filePaths) : ResponseDef(message, result, errCode), filePaths_(filePaths) {} + // 数据转换成json格式,返回给前端 + value_t CreatDataValue(document_t::AllocatorType& allocator) override + { + value_t data(rapidjson::kArrayType); + for (const auto& filePath: filePaths_) { + data.PushBack(value_t().SetString(filePath.c_str(), allocator), allocator); + } + value_t body(rapidjson::kObjectType); + body.AddMember("data", data, allocator); + return body; + } +}; + +// 返回数据按照文件 - tag - 直方图的形式 +struct GetHistoDataResponse : public ResponseDef{ + std::map> filePathToInfo_; + GetHistoDataResponse() = default; + explicit GetHistoDataResponse(const std::map> &filePathToInfo) : + filePathToInfo_(filePathToInfo) {} + GetHistoDataResponse(const std::string & message, bool result, int errCode, + const std::map> &filePathToInfo) : + ResponseDef(message, result, errCode), filePathToInfo_(filePathToInfo) {} + // 数据转换成json格式,返回给前端 + value_t CreatDataValue(document_t::AllocatorType& allocator) override + { + value_t data(rapidjson::kArrayType); + for (const auto& filePair: filePathToInfo_) { + for (const auto& tagPair: filePair.second) { + value_t filetagHisto(rapidjson::kObjectType); + filetagHisto.AddMember("filePath", value_t().SetString(filePair.first.c_str(), allocator), allocator); + filetagHisto.AddMember("tag", value_t().SetString(tagPair.first.c_str(), allocator), allocator); + HistogramGraph histogram = tagPair.second; + filetagHisto.AddMember("histogramGraph", histogram.CreatHistogramValue(allocator), allocator); + + data.PushBack(filetagHisto, allocator); + } + } + return data; + } +}; + +} +#endif //RESPONSE_DEF_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.cpp b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..54cd3a84c7ba320defa2d6fa9931b04cd2454634 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.cpp @@ -0,0 +1,21 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ + +#include "AddImportFileRequestHandler.h" +#include "histoParser/FileParser.h" +#include "defs/ResponseDef.h" +#include "utils/FileUtils.h" +#include "defs/RequestDef.h" + +namespace Insight::Histogram::handler { +using namespace Insight::Histogram::Parser; + bool AddImportFileRequestHandler::HandleRequest(std::string_view data, std::string &resultStr) + { + AddImportFileRequest req(data); + // 解析并获取文件和tag映射数据 + ImportResponse rsp(GetTagsByFilePath(req.filepathList)); + resultStr = rsp.ToJsonString(); + return true; + } +} // Histogram diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..a25a59dd66115333134448892498fe300563ad3d --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/AddImportFileRequestHandler.h @@ -0,0 +1,18 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ + +#ifndef UPDATEIMPORTFILEREQUESTHANDLER_H +#define UPDATEIMPORTFILEREQUESTHANDLER_H + +#include +#include "ImportRequestHandler.h" + +namespace Insight::Histogram::handler { + class AddImportFileRequestHandler : public ImportRequestHandler { + bool HandleRequest(std::string_view data, std::string &resultStr) override; + }; +} // Histogram +#endif //UPDATEIMPORTFILEREQUESTHANDLER_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.cpp b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b6c33a363daac6b3096edc364fb1d4258bc37a07 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.cpp @@ -0,0 +1,65 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ + +#include "GetHistoDataRequestHandler.h" +#include "histogramViewManager.h" +#include "histoParser/ParserFactory.h" +#include "histoParser/FileParser.h" +#include "defs/ResponseDef.h" +#include "utils/FileUtils.h" +#include "RequestDef.h" + +namespace Insight::Histogram::handler { +using namespace Insight::Histogram::Parser; + + bool GetHistoDataRequestHandler::run(std::string_view data, std::string &result) { + return HandleRequest(data, result); + } + + bool GetHistoDataRequestHandler::HandleRequest(std::string_view data, std::string &resultStr) + { + GetHistoDataRequest req(data); + std::map> pathToTagsMap = req.filePathTotags; + + HistogramViewManager& viewManager = HistogramViewManager::getInstance(); + + std::map> filePathToInfo; + // 遍历请求命令里的每一个文件 + ParserFactory parserFactory; + for (auto & pathToTags : pathToTagsMap) { + std::string filePath = pathToTags.first; + if (pathToTags.second.empty()) { + // 这个文件路径下想要的tag 图为空,则不再往下走节省时间 + continue; + } + uint64_t offset = viewManager.GetReaderOffsetByFilePath(filePath); + std::shared_ptr parser = parserFactory.CreateFileParse(filePath); + if (!parser) { + continue; + } + // 从上次解析的偏移量开始解析新的数据 + parser->ParseData(filePath, offset); + // 获取到最新数据交给infomanager追加到老数据里并做更新 + viewManager.AppendNewFileInfo(filePath, offset, parser->GetTags(), parser->GetTagToHistoGraph()); + // 获取合并之后的新数据 + std::map newData = viewManager.GetHistoInfoByFilePath(filePath); + filePathToInfo[filePath] = FilterHistoDataByTag(pathToTags.second, viewManager.GetHistoInfoByFilePath(filePath)); + }; + GetHistoDataResponse rsp(filePathToInfo); + resultStr = rsp.ToJsonString(); + return true; + } + + // 根据前端请求返回需要的tag数据 + std::map GetHistoDataRequestHandler::FilterHistoDataByTag(std::set tags, std::map histoGraphs) + { + std::map result; + for (auto & tag : tags) { + if (histoGraphs.find(tag) != histoGraphs.end()) { + result[tag] = histoGraphs[tag]; + } + } + return result; + } +} // Histogram diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..0c20a1ac00a159efb1f2e7e824b5a5e86c42ec86 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetHistoDataRequestHandler.h @@ -0,0 +1,29 @@ +// +// Created by admin on 2024/12/11. +// + +#ifndef GET_HISTO_DATA_REQUEST_HANDLER_H +#define GET_HISTO_DATA_REQUEST_HANDLER_H + +#include +#include +#include +#include +#include "iostream" +#include "ApiHandler.h" +#include "defs/RequestDef.h" + +namespace Insight::Histogram::handler { + + class GetHistoDataRequestHandler : public Dic::Core::PostHandler{ + public: + bool run(std::string_view data, std::string &result) override; + private: + bool HandleRequest(std::string_view data, std::string &resultStr); + std::map FilterHistoDataByTag(std::set tags, + std::map histoGraphs); + }; + +} // Histogram + +#endif //GET_HISTO_DATA_REQUEST_HANDLER_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetNewFilesRequestHandler.cpp b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetNewFilesRequestHandler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3aee5a9dca8eddf8476dc0dc675817e185f9be8b --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetNewFilesRequestHandler.cpp @@ -0,0 +1,21 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#include "GetNewFilesRequestHandler.h" +#include "histogramViewManager.h" +#include "defs/ResponseDef.h" + +namespace Insight::Histogram::handler { + bool GetNewFilesRequestHandler::run(std::string_view data, std::string &result) { + return HandleRequest(data, result); + } + + bool GetNewFilesRequestHandler::HandleRequest(std::string_view data, std::string &resultStr) + { + // 获取通过文件监控得到的新增文件 + GetNewFilesResponse rsp(HistogramViewManager::getInstance().GetNewFiles()); + // 每次请求过后清空新增文件列表,重新记录从本地请求结束到下次请求开始这个时间段的请求 + resultStr = rsp.ToJsonString(); + return true; + } +} \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetNewFilesRequestHandler.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetNewFilesRequestHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..91eec5f3f55c5a636746c96ff5b05c5c1ef763b4 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/GetNewFilesRequestHandler.h @@ -0,0 +1,25 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ + +#ifndef GET_NEWFILES_REQUEST_HANDLER_H +#define GET_NEWFILES_REQUEST_HANDLER_H + +#include +#include +#include +#include "iostream" +#include "ApiHandler.h" + +namespace Insight::Histogram::handler { + + class GetNewFilesRequestHandler : public Dic::Core::PostHandler{ + public: + bool run(std::string_view data, std::string &result) override; + private: + bool HandleRequest(std::string_view data, std::string &resultStr); + }; + +} // Histogram + +#endif //GET_NEWFILES_REQUEST_HANDLER_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/ImportProjectRequestHandler.cpp b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/ImportProjectRequestHandler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4ac6286f6ec94a2caa91dff8f8a677dc93ae5817 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/ImportProjectRequestHandler.cpp @@ -0,0 +1,36 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ + +#include "ImportProjectRequestHandler.h" +#include "histoParser/ParserFactory.h" +#include "defs/ResponseDef.h" +#include "utils/FileUtils.h" +#include "RequestDef.h" + +namespace Insight::Histogram::handler { + using namespace Insight::Histogram::Parser; + + bool ImportProjectRequestHandler::HandleRequest(std::string_view data, std::string &resultStr) + { + // 工程最开始进来的时候先清空所有数据,停掉文件监控 + HistogramViewManager::getInstance().Reset(); + ImportRequest req(data); + std::vector pathList = req.rootPathList; + if (!fs::is_directory(pathList[0])) { + ResponseDef response("ERROR: The path is not a directory.", false, 0); + resultStr = response.ToJsonString(); + return false; + } + // 根目录只有一个,每次只查看一个工程下的数据,不支持集群数据 + std::vector fileList = util::GetListFilesInDirectory(req.rootPathList[0]); + + // 获取到当前文件夹下的所有文件之后,开启文件监控,记录新增文件 + HistogramViewManager::getInstance().StartFileWatch(req.rootPathList[0]); + std::map> fileToTags = GetTagsByFilePath(fileList); + ImportResponse rsp(fileToTags); + resultStr = rsp.ToJsonString(); + return true; + } + +} // Histogram diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/ImportProjectRequestHandler.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/ImportProjectRequestHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..ea4ec5777c349b3cdce04455105b5c37eb9f2ed9 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/ImportProjectRequestHandler.h @@ -0,0 +1,17 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef IMPORT_FILE_REQUEST_HANDLER_H +#define IMPORT_FILE_REQUEST_HANDLER_H + +#include +#include "ImportRequestHandler.h" + +namespace Insight::Histogram::handler { + +class ImportProjectRequestHandler : public ImportRequestHandler { + bool HandleRequest(std::string_view data, std::string &resultStr) override; +}; + +} // Histogram +#endif //IMPORT_FILE_REQUEST_HANDLER_H \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/ImportRequestHandler.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/ImportRequestHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..71297184ed69aeb67c3a724bf089d18be60e100a --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/handler/ImportRequestHandler.h @@ -0,0 +1,50 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ + +#ifndef IMPORTREQUESTHANDLER_H +#define IMPORTREQUESTHANDLER_H + +#include +#include +#include +#include "ApiHandler.h" +#include "histogramViewManager.h" +#include "histoParser/ParserFactory.h" +#include "histoParser/FileParser.h" + +namespace Insight::Histogram::handler { + using namespace Insight::Histogram::Parser; + class ImportRequestHandler : public Dic::Core::PostHandler{ + public: + bool run(std::string_view data, std::string &result) override { + return HandleRequest(data, result); + } + protected: + virtual bool HandleRequest(std::string_view data, std::string &resultStr) = 0; + + std::map> GetTagsByFilePath(const std::vector &fileList) + { + std::map> filepathToTags; + ParserFactory parserFactory; + for (const auto& filepath :fileList) { + // 逐个解析每个文件的数据 + uint64_t offset = 0; + std::shared_ptr parser = parserFactory.CreateFileParse(filepath); + if (!parser) { + continue; + } + parser->ParseData(filepath, offset); + std::set tags = parser->GetTags(); + // 获取文件到tag的对应关系 + filepathToTags[filepath] = tags; + // 已经解析完的数据存储在viewManager单例里, tag 和 具体图的数据都要存储进去 + HistogramViewManager::getInstance().SetViewFileList(filepath, offset, tags, parser->GetTagToHistoGraph()); + } + // 把filepathToTags 塞给response转换成json返回给前端 + return filepathToTags; + } + }; + +} // Histogram +#endif //IMPORTREQUESTHANDLER_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/FileParser.cpp b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/FileParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77b348f649a7196e33373c2a39264ac83a87db2e --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/FileParser.cpp @@ -0,0 +1,92 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include "FileParser.h" +#include "Logger.h" +#include "defs/HistoConceptDefs.h" + +using namespace Insight::Histogram::Parser; +using namespace Insight::Histogram; +using namespace Insight; + +bool FileParser::CheckFilePathVaild(const std::string &filePath) { + if (filePath.empty()) { + LOG(LogRank::Error) << "File path is empty"; + return false; + } + if (!fs::exists(filePath)) { + LOG(LogRank::Error) << "File path does not exist"; + return false; + } + auto readPermission = fs::status(filePath).permissions() & fs::perms::owner_read; + if (readPermission == fs::perms::none) { + LOG(LogRank::Error) << "File not permit to read"; + return false; + } + return true; +} + +std::ifstream FileParser::OpenFileSafe(const std::string &filePath) { + std::ifstream file; + file.setstate(std::ios::badbit); + if (!CheckFilePathVaild(filePath)) { + LOG(LogRank::Error) << "Open file failed"; + return file; + } + file = std::ifstream(filePath, std::ios::in | std::ios::binary); + return file; +} + +bool FileParser::ParseData(const std::string &filePath, uint64_t &offset) { + std::ifstream file = OpenFileSafe(filePath); + if (!file.is_open()) { + LOG(LogRank::Warning) << "Parse data faild, open file error"; + return false; + } + file.seekg(static_cast(offset), std::ios::beg); + std::string recordStr; + while (file && ReadRecord(recordStr, file)) { + if (!ParseRecord(std::move(recordStr))) { + break; + } + std::streampos index = file.tellg(); + if (index != -1) { + offset = index; + } + } + // 读到二进制数据之后转换成tag - graph 的映射方便前端画图 + return true; +} + +bool FileParser::ReadCheckSumRecord(std::ifstream &input, std::vector &buffer, size_t size) { + if (!input) { + return false; + } + if (size > std::numeric_limits::max()) { + LOG(LogRank::Error) << "Read data exceed limit"; + return false; + } + + buffer.clear(); + buffer.resize(size + 1); + input.read(buffer.data(), static_cast(size)); + if (input.gcount() != size) { + return false; + } + uint32_t ccrc = 0; + input.read(reinterpret_cast(&ccrc), sizeof(uint32_t)); + if (input.gcount() != sizeof(uint32_t)) { + return false; + } + return true; +} + +std::set FileParser::GetTags() +{ + return tags; +} + +std::map FileParser::GetTagToHistoGraph() +{ + return tagTohistoGraph; +} diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/FileParser.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/FileParser.h new file mode 100644 index 0000000000000000000000000000000000000000..2042bbcbeaf4518add830d9886379f38c000281f --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/FileParser.h @@ -0,0 +1,56 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef FILEPARSER_H +#define FILEPARSER_H + +#include +#include +#include +#include +#include +#include "defs/HistoConceptDefs.h" + +namespace Insight::Histogram::Parser { +/** + * @brief This base class of data file parser + */ +class FileParser { +public: + FileParser() : type_(ParseDataType::Unknown) {} + + /** + * @brief open file stream safely + * @param filePath + * @return + */ + std::ifstream OpenFileSafe(const std::string &filePath); + /** + * + * @param filePath + * @param[in/out] offset file read offset, update after parsed + * @return + */ + + bool ParseData(const std::string &filePath, uint64_t &offset); + std::set GetTags(); + std::map GetTagToHistoGraph(); + + virtual bool ReadRecord(std::string &recordStr, std::ifstream &input) = 0; + static bool ReadCheckSumRecord(std::ifstream &input, std::vector &buffer, size_t size); + virtual bool ParseRecord(std::string &&record) = 0; + virtual ~FileParser() = default; + +protected: + std::set tags; + std::map tagTohistoGraph; + +private: + bool CheckFilePathVaild(const std::string &filePath); + +public: + ParseDataType type_; +}; +} + +#endif //FILEPARSER_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/MindsporeParser.cpp b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/MindsporeParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3e25b68a555ea488e68ecfbb34daf5320eb7ccb1 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/MindsporeParser.cpp @@ -0,0 +1,59 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include "proto/mindspore_summary.pb.h" +#include "Logger.h" +#include "MindsporeParser.h" + +using namespace Insight::Histogram::Parser; +using namespace Insight::Histogram; +using namespace Insight; + +bool MindsporeParser::ReadRecord(std::string &eventStr, std::ifstream &input) { + /* + * The structure of tf event file: + * | data length | CRC sum | pb data | CRC sum | + * | uint64_t | uint32_t | .... | uint32_t | + */ + std::vector record; + if (!ReadCheckSumRecord(input, record, sizeof(uint64_t))) { + return false; + } + uint64_t length = 0; + memcpy(&length, record.data(), sizeof(uint64_t)); + if (!ReadCheckSumRecord(input, record, length)) { + return false; + } + eventStr = std::string(record.data(), length); + return true; +} + +bool MindsporeParser::ParseRecord(std::string &&record) { + mindspore::irpb::Event event; + if (!event.ParseFromString(record)) { + LOG(LogRank::Error) << "Can't convert str to tensorflow event"; + return false; + } + uint64_t step = event.step(); + if (!event.has_summary()) { + return true; + } + for (const auto& value : event.summary().value()) { + if (!value.has_histogram()) { + continue; + } + // 存储这个文件的相关tag + tags.insert(value.tag()); + // 获取原始的histo 并存储这些原始数据到tagTohistolines里 + const auto& histogram = value.histogram(); + // 由tag 到直方图的每一条线的对应关系存储数据 + // 如果这个tag到图的对应关系没有,就新建一个 + if (tagTohistoGraph.find(value.tag()) == tagTohistoGraph.end()) { + tagTohistoGraph[value.tag()] = HistogramGraph(); + } + HistogramLine line(step, histogram); + tagTohistoGraph[value.tag()].AddValue(line); + } + return true; +} + diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/MindsporeParser.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/MindsporeParser.h new file mode 100644 index 0000000000000000000000000000000000000000..6f5209ce42ec8b64de1659dc0863f5b4ddd54392 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/MindsporeParser.h @@ -0,0 +1,30 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef MINDSPOREPARSER_H +#define MINDSPOREPARSER_H + +#include "FileParser.h" +#include "defs/HistoConceptDefs.h" + +namespace Insight::Histogram::Parser { + class MindsporeParser final : public FileParser { + public: + MindsporeParser() { + type_ = ParseDataType::MindSpore_Summary; + } + + ~MindsporeParser() override = default; + + bool ReadRecord(std::string &eventStr, std::ifstream &input) override; + private: + /** + * @brief check whether contains scalar value + * @param event MindSpore_Summary event object + * @return true for success + */ + bool ParseRecord(std::string &&record) override; + }; +} + +#endif //MINDSPOREPARSER_H \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/ParserFactory.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/ParserFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..de53be94d92e76ebe4143ee1cc878631ed085b61 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/ParserFactory.h @@ -0,0 +1,60 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef PARSERFACTORY_H +#define PARSERFACTORY_H + +#include +#include +#include "defs/HistoConceptDefs.h" +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else + +#include + +namespace fs = std::experimental::filesystem; +#endif +#include "TFEventParser.h" +#include "MindsporeParser.h" + +namespace Insight::Histogram::Parser { +using namespace Insight::Histogram; + +class ParserFactory { +public: + std::shared_ptr CreateFileParse(std::string_view filePath) { + ParseDataType type = GetFileType(filePath); + if (parsers_.find(type) == parsers_.end()) { + return nullptr; + } + return parsers_[type]; + } +private: + inline ParseDataType GetFileType(std::string_view filePath) { + for (const auto &[k, v]: fileTypeMap_) { + std::regex regex(k); + std::smatch match; + std::string fileName = fs::path(filePath).filename().string(); + if (std::regex_search(fileName, match, regex)) { + return v; + } + } + return ParseDataType::Unknown; + } + + std::map > parsers_ = { + {ParseDataType::TF_EVENT, std::make_shared()}, + {ParseDataType::MindSpore_Summary, std::make_shared()}, + {ParseDataType::Unknown, nullptr} + }; + + std::map fileTypeMap_ = { + {R"(out.tfevent)", ParseDataType::TF_EVENT}, + {R"(out.events.summary)", ParseDataType::MindSpore_Summary} + }; +}; +} + +#endif //PARSERFACTORY_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/TFEventParser.cpp b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/TFEventParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5b8e2a3dc639c5d6653d651404e7b50db1c4d60 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/TFEventParser.cpp @@ -0,0 +1,59 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include + +#include "TFEventParser.h" +#include "proto/event.pb.h" +#include "Logger.h" + +using namespace Insight::Histogram::Parser; +using namespace Insight::Histogram; +using namespace Insight; + +bool TFEventParser::ReadRecord(std::string &eventStr, std::ifstream &input) { + /* + * The structure of tf event file: + * | data length | CRC sum | pb data | CRC sum | + * | uint64_t | uint32_t | .... | uint32_t | + */ + std::vector record; + if (!ReadCheckSumRecord(input, record, sizeof(uint64_t))) { + return false; + } + uint64_t length = 0; + memcpy(&length, record.data(), sizeof(uint64_t)); + if (!ReadCheckSumRecord(input, record, length)) { + return false; + } + eventStr = std::string(record.data(), length); + return true; +} + +bool TFEventParser::ParseRecord(std::string &&record) { + tensorboard::Event event; + if (!event.ParseFromString(record)) { + LOG(LogRank::Error) << "Can't convert str to tensorflow event"; + return false; + } + uint64_t step = event.step(); + if (!event.has_summary()) { + return true; + } + for (const auto& value : event.summary().value()) { + if (!value.has_histo()) { + continue; + } + // 存储这个文件的相关tag + tags.insert(value.tag()); + // 获取原始的histo 并存储这些原始数据到tagTohistolines里 + const auto& histogram = value.histo(); + // 由tag 到直方图的每一条线的对应关系存储数据 + if (tagTohistoGraph.find(value.tag()) == tagTohistoGraph.end()) { + tagTohistoGraph[value.tag()] = HistogramGraph(); + } + HistogramLine line(step, histogram); + tagTohistoGraph[value.tag()].AddValue(line); + } + return true; +} diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/TFEventParser.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/TFEventParser.h new file mode 100644 index 0000000000000000000000000000000000000000..de0dab93bab730d08b4e8165e72e2a264de88f80 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/histoParser/TFEventParser.h @@ -0,0 +1,30 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef TFEVENTPARSER_H +#define TFEVENTPARSER_H + +#include "FileParser.h" +#include "defs/HistoConceptDefs.h" + +namespace Insight::Histogram::Parser { +class TFEventParser final : public FileParser { +public: + TFEventParser() { + type_ = ParseDataType::TF_EVENT; + } + + ~TFEventParser() override = default; + + bool ReadRecord(std::string &eventStr, std::ifstream &input) override; +private: + /** + * @brief check whether contains scalar value + * @param event tf event object + * @return true for success + */ + bool ParseRecord(std::string &&record) override; +}; +} + +#endif //TFEVENTPARSER_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/plugin/HistoVisualPlugin.cpp b/plugins/mindstudio-insight-plugins/Histogram/server/src/plugin/HistoVisualPlugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cbac8b766ea3d49ba0386f94b3d5aa677fb65843 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/plugin/HistoVisualPlugin.cpp @@ -0,0 +1,53 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ + +#include "document.h" +#include "HistoVisualPlugin.h" + +#include "AddImportFileRequestHandler.h" +#include "GetHistoDataRequestHandler.h" +#include "GetNewFilesRequestHandler.h" +#include "PluginsManager.h" +#include "handler/ImportProjectRequestHandler.h" +#include "utils/HistogramProtocolUtil.h" + +using namespace Insight; +namespace Insight::Histogram { + + using namespace Insight::Histogram::Protocol; + using json_t = rapidjson::Value; + using document_t = rapidjson::Document; + + HistoVisualPlugin::HistoVisualPlugin() : Dic::Core::BasePlugin("HistoVisually") { + handlers_.emplace("ImportProject", std::make_shared()); + handlers_.emplace("GetNewFiles", std::make_shared()); + handlers_.emplace("GetHistoData", std::make_shared()); + handlers_.emplace("AddImportFile", std::make_shared()); + } + + std::map> HistoVisualPlugin::GetAllHandlers() { + std::map> res(handlers_.begin(), handlers_.end()); + return res; + } + + std::vector HistoVisualPlugin::GetModuleConfig() { + std::vector res; + document_t moduleConfig; + moduleConfig.SetObject(); + auto &allocator = moduleConfig.GetAllocator(); + AddJsonMember(moduleConfig, "name", "Histogram", allocator); + AddJsonMember(moduleConfig, "requestName", "Histogram", allocator); + json_t attributes(rapidjson::kObjectType); + AddJsonMember(attributes, "src", "./plugins/Histogram/index.html", allocator); + AddJsonMember(moduleConfig, "attributes", attributes, allocator); + AddJsonMember(moduleConfig, "isDefault", true, allocator); + AddJsonMember(moduleConfig, "isCluster", true, allocator); + AddJsonMember(moduleConfig, "isCompute", true, allocator); + AddJsonMember(moduleConfig, "isJupyter", true, allocator); + res.push_back(DumpJsonToStr(moduleConfig)); + return res; + } +} + +Dic::Core::PluginRegister pluginRegister(std::move(std::make_unique())); diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/plugin/HistoVisualPlugin.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/plugin/HistoVisualPlugin.h new file mode 100644 index 0000000000000000000000000000000000000000..b77d15c3846682094dee4615ee64505309ef2f8d --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/plugin/HistoVisualPlugin.h @@ -0,0 +1,28 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_SCALARVISUALPLUGIN_H +#define BOARD_SCALARVISUALPLUGIN_H + +#include "BasePlugin.h" + +namespace Insight::Histogram { +using namespace Dic::Core; + +class HistoVisualPlugin : public BasePlugin { +public: + HistoVisualPlugin(); + + ~HistoVisualPlugin() override = default; + + std::map> GetAllHandlers() override; + + std::vector GetModuleConfig() override; + + +private: + std::map> handlers_; +}; +} + +#endif //BOARD_SCALARVISUALPLUGIN_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/utils/FileUtils.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/utils/FileUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..77640f845be97b087404f0734a618b54fa207d5e --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/utils/FileUtils.h @@ -0,0 +1,38 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef FILEUTILS_H +#define FILEUTILS_H + +#include +#include +#include "iostream" +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else + +#include + +namespace fs = std::experimental::filesystem; +#endif + +namespace Insight::Histogram::util { + + inline static std::vector GetListFilesInDirectory(const fs::path& dirPath) { + std::vector fileList; + if (!fs::exists(dirPath) || !fs::is_directory(dirPath)) { + std::cerr << "Error: path does not exist or is not a directory." << std::endl; + return fileList; + } + for (const auto& entry : fs::recursive_directory_iterator(dirPath)) { + const auto& path = entry.path(); + if (fs::is_regular_file(path)) { + fileList.push_back(path.string()); + } + } + return fileList; + } +} + +#endif //FILEUTILS_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/utils/HistogramProtocolUtil.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/utils/HistogramProtocolUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..6767337ace6d6b91783a85c24fd6990b475e34fa --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/utils/HistogramProtocolUtil.h @@ -0,0 +1,154 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_PLUGINS_HISTOGRAMVISUALLY_SRC_PROTOCOLUTIL_PROTOCOLUTIL_H_ +#define BOARD_PLUGINS_HISTOGRAMVISUALLY_SRC_PROTOCOLUTIL_PROTOCOLUTIL_H_ + +#include "document.h" +#include "writer.h" + +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif + +#include +#include +#include +#include +#include + +namespace Insight::Histogram::Protocol { +using json = rapidjson::Value; +using document_t = rapidjson::Document; + +template +static inline std::optional TryParseJson(std::string_view jsonStr, std::string &error) { + document_t doc; + doc.Parse(jsonStr.data(), jsonStr.length()); + if (doc.HasParseError()) { + constexpr size_t printErrorSize = 10; + auto offset = doc.GetErrorOffset(); + auto start = offset >= printErrorSize ? offset - printErrorSize : 0; + error = "Error code:" + std::to_string(doc.GetParseError()) + ", str:" + + std::string(jsonStr.substr(start, offset - start + printErrorSize)); + return std::nullopt; + } + return std::move(doc); +} + +static inline document_t ParseJsonToStr(std::string_view jsonStr) { + document_t doc; + doc.Parse(jsonStr.data(), jsonStr.length()); + return doc; +} + +static inline std::string DumpJsonToStr(const json &document) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + document.Accept(writer); + return {buffer.GetString(), buffer.GetSize()}; +} + +template +static inline void AddJsonMember(json &dst, + std::string_view key, + T &&value, + rapidjson::MemoryPoolAllocator<> &allocator) { + dst.AddMember(rapidjson::StringRef(key.data(), key.length()), std::forward(value), allocator); +} + +static inline void AddJsonMember(json &dst, + std::string_view key, + const std::string &value, + rapidjson::MemoryPoolAllocator<> &allocator) { + dst.AddMember(rapidjson::StringRef(key.data(), key.length()), + rapidjson::Value().SetString(value.data(), value.size(), allocator), + allocator); +} + +static inline void AddJsonMember(json &dst, + std::string_view key, + std::string &value, + rapidjson::MemoryPoolAllocator<> &allocator) { + dst.AddMember(rapidjson::StringRef(key.data(), key.length()), + rapidjson::Value().SetString(value.data(), value.size(), allocator), + allocator); +} + +static inline void AddJsonMember(json &dst, + std::string_view key, + std::string &&value, + rapidjson::MemoryPoolAllocator<> &allocator) { + dst.AddMember(rapidjson::StringRef(key.data(), key.length()), + rapidjson::Value().SetString(value.data(), value.size(), allocator), + allocator); +} + +template +static inline void AddJsonMember(json &dst, + std::string_view key, + std::vector &value, + rapidjson::MemoryPoolAllocator<> &allocator) { + json temp(rapidjson::kArrayType); + for (const T &item: value) { + if constexpr (std::is_same_v) { + temp.PushBack(json().SetString(item.data(), item.size(), allocator), allocator); + } else { + temp.PushBack(std::forward(item), allocator); + } + } + AddJsonMember(dst, key, temp, allocator); +} + +template +static inline void AddJsonMember(json &dst, + std::string_view key, + std::set &value, + rapidjson::MemoryPoolAllocator<> &allocator) { + json temp(rapidjson::kArrayType); + for (const T &item: value) { + if constexpr (std::is_same_v) { + temp.PushBack(json().SetString(item.data(), item.size(), allocator), allocator); + } else { + temp.PushBack(std::forward(item), allocator); + } + } + AddJsonMember(dst, key, temp, allocator); +} + +static inline void SetResponseError(int errCode, const std::string &errMsg, std::string &resultStr) { + document_t result = ParseJsonToStr(resultStr); + auto &allocator = result.GetAllocator(); + result["errCode"].SetInt(errCode); + result["msg"].SetString(errMsg.c_str(), errMsg.size(), allocator); + result["result"].SetBool(false); + resultStr = DumpJsonToStr(result); +} + +static inline std::string GetBasicResponse() { + return R"({"body":{}, "msg":"", "errCode":0, "result":true})"; +} + +static inline std::string GetReadableFileName(std::string_view path) { + if (path.empty()) { + return ""; + } + auto curPath = fs::path(path); + auto fileName = curPath.filename(); + if (!curPath.has_parent_path()) { + return std::string(path); + } + auto parentPath = curPath.parent_path(); + if (!parentPath.has_filename()) { + return fileName.string(); + } + auto parentDir = parentPath.filename(); + auto res = parentDir / fileName; + return res.string(); +} +} +#endif //BOARD_PLUGINS_HISTOGRAMVISUALLY_SRC_PROTOCOLUTIL_PROTOCOLUTIL_H_ diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/FileMonitor.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/FileMonitor.h new file mode 100644 index 0000000000000000000000000000000000000000..081128e4cf4fc3c3c29902bff67d1801327a8934 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/FileMonitor.h @@ -0,0 +1,93 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef FILEMONITOR_H +#define FILEMONITOR_H +#include +#include +#include +#include +#include +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else + +#include + +namespace fs = std::experimental::filesystem; +#endif +#include +#include + +namespace Insight::Histogram { + class FileMonitor { + public: + FileMonitor() : monitoring(false) {} + + void StartMonitoring(const std::string& path) { + directoryPath = path; + monitoring = true; + monitorThread = std::thread(&FileMonitor::MonitorDirectory, this); + } + + void StopMonitoring() { + { + std::lock_guard lock(mtx); + newFiles.clear(); + monitoring = false; + } + cv.notify_one(); + if (monitorThread.joinable()) { + monitorThread.join(); + } + } + + std::vector GetNewFiles() { + std::lock_guard lock(mtx); + std::vector files =newFiles; + newFiles.clear(); + return files; + } + + private: + void MonitorDirectory() { + std::unordered_set existingFiles = GetAllFiles(directoryPath); + while (true) { + std::this_thread::sleep_for(std::chrono::milliseconds(10000)); + std::unordered_set currentFiles = GetAllFiles(directoryPath); + { + std::lock_guard lock(mtx); + if (!monitoring) { + break; + } + for (const auto& file : currentFiles) { + if (existingFiles.find(file) == existingFiles.end()) { + newFiles.push_back(file); + } + } + } + + existingFiles = std::move(currentFiles); + } + } + + std::unordered_set GetAllFiles(const std::string& path) { + std::unordered_set files; + for (const auto& entry : fs::recursive_directory_iterator(path)) { + if (fs::is_regular_file(entry.path())) { + files.insert(entry.path().string()); + } + } + return files; + } + + std::string directoryPath; + std::vector newFiles; + std::thread monitorThread; + std::mutex mtx; + std::condition_variable cv; + bool monitoring; + }; +} +#endif //FILEMONITOR_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/ViewFile.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/ViewFile.h new file mode 100644 index 0000000000000000000000000000000000000000..d6d8fd1b5b6c0d9380296e3b37e9d1cd59ec0d94 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/ViewFile.h @@ -0,0 +1,33 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef VIEW_FILE_GRAPH_H +#define VIEW_FILE_GRAPH_H + +#include +#include +#include + +#include "defs/HistoConceptDefs.h" + +using namespace Insight; +namespace Insight::Histogram { +using namespace Insight::Histogram; + +/* +* 每个要在前端呈现的界面基本内容 +* 包括文件完整路径 +* tag和图的对应关系 +* 以及上一次读取的文件位置 +*/ +class ViewFile { +public: +private: + std::string filePath; + std::streampos lastreadPos; + // 该文件里tag 和 图的对应关系 + std::unordered_map> tagToHistograms; +}; + +} // Histogram +#endif //VIEW_FILE_GRAPH_H diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/histogramViewGraphs.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/histogramViewGraphs.h new file mode 100644 index 0000000000000000000000000000000000000000..90c5a76c3e6c24bc5a453f3e1c77fc7854967d22 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/histogramViewGraphs.h @@ -0,0 +1,34 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_HISTOGRAM_VIEW_GRAPH_H +#define BOARD_HISTOGRAM_VIEW_GRAPH_H + +#include +#include +#include + +#include "defs/HistoConceptDefs.h" + +using namespace Insight; +namespace Insight::Histogram { +using namespace Insight::Histogram; + +/* +* 这个类是关于界面呈现的所有数据的综合存储,包括从文件到tag 以及从tag到每一张图的对应 +* 当文件或者路径有变更时,由 管理类进行更新 +*/ +class HistogramViewGraphs { +public: +private: + void AddFilePathToTag(); + void AddTagToHistogram(); + void Reset(); + + // 在界面上呈现的数据映射关系 文件 - tag - 图,调用parser的时候更新 + std::unordered_map> filePathToTag; + std::unordered_map> tagToHistograms; +}; + +} // Histogram +#endif //BOARD_HISTOGRAM_VIEW_GRAPH_H \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/histogramViewManager.h b/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/histogramViewManager.h new file mode 100644 index 0000000000000000000000000000000000000000..9adc7f3a7d053883b4f0f626eb59e85cf6d23c84 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Histogram/server/src/viewManager/histogramViewManager.h @@ -0,0 +1,133 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_HISTOGRAM_VIEW_MANAGER_H +#define BOARD_HISTOGRAM_VIEW_MANAGER_H + +#include +#include +#include +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else + +#include + +namespace fs = std::experimental::filesystem; +#endif + +#include "ViewFile.h" +#include "FileMonitor.h" + +namespace Insight::Histogram { +// 在界面上呈现的文件数据 +struct ViewFileInfo { + // 界面上已解析文件的偏移量,主要用来在文件有变更时做追加 + uint64_t offset_{0}; + // 该文件的tag数据 + std::set tags_; + // 解析之后的直方图数据,tag - HistogramLines的映射形式,有新增则追加 + std::map tagTohistoGraph_; + + ViewFileInfo() = default; + ViewFileInfo(uint64_t offset, std::set tags, + std::map tagTohistoGraph) : + offset_(offset), tagTohistoGraph_(tagTohistoGraph), tags_(tags) {} + + void Update(uint64_t offset, std::set tags, + std::map tagTohistoGraph) { + // 偏移量需要更新 + offset_ = offset; + // 其他数据需要和老数据合并 + tags_.insert(tags.begin(), tags.end()); + for (auto tagPair : tagTohistoGraph) { + // 没有就重新加一个 + if (tagTohistoGraph_.count(tagPair.first) == 0) { + tagTohistoGraph_[tagPair.first] = tagPair.second; + } else { + tagTohistoGraph_[tagPair.first].MergeData(tagPair.second); + } + } + } +}; + +class HistogramViewManager { +public: + // 开启文件监控,监控新增文件以及有变更的文件 + HistogramViewManager(const HistogramViewManager&) = delete; + HistogramViewManager& operator=(const HistogramViewManager&) = delete; + HistogramViewManager(HistogramViewManager &&) = delete; + static HistogramViewManager& getInstance() + { + static HistogramViewManager instance; + return instance; + } + + void SetViewFileList(const std::string &filePath, uint64_t offset, std::set tags, + std::map tagTohistoGraph) + { + ViewFileInfo fileInfo = ViewFileInfo(offset, tags, tagTohistoGraph); + ViewFileList[filePath] = fileInfo; + } + // 根据前端请求返回需要的tag数据 + void AppendNewFileInfo(const std::string &filePath, uint64_t offset, std::set tags, + std::map newGraphs) + { + if (newGraphs.empty()) { + return; + } + // 之前没存储过这个文件,就直接存储进来 + if (ViewFileList.find(filePath) == ViewFileList.end()) { + SetViewFileList(filePath, offset, tags, newGraphs); + return; + } + // 之前有存储过,则需要更新文件信息 + ViewFileList[filePath].Update(offset, tags, newGraphs); + } + + + uint64_t GetReaderOffsetByFilePath(std::string filePath) { + if (ViewFileList.find(filePath) != ViewFileList.end()) { + return ViewFileList[filePath].offset_; + } + return 0; + } + std::map GetHistoInfoByFilePath(std::string filePath) { + std::map info; + if (ViewFileList.find(filePath) != ViewFileList.end()) { + return ViewFileList[filePath].tagTohistoGraph_; + } + return info; + } + + // 文件监控相关操作 + void StartFileWatch(const std::string &rootPath) { + // 启动文件监控线程 + fileMonitor.StartMonitoring(rootPath); + } + void StopFileWatch() { + fileMonitor.StopMonitoring(); + } + std::vector GetNewFiles() { + return fileMonitor.GetNewFiles(); + } + + // 需要新增一个清空,当import命令进来的时候这个单例要清空掉,并停止文件监控。重新开始记录 + void Reset() { + ViewFileList.clear(); + StopFileWatch(); + } +private: + // 私有构造函数防止实例化 + HistogramViewManager() {} + ~HistogramViewManager() { + StopFileWatch(); + } + // 已经在界面呈现的文件 文件名-文件信息 + std::unordered_map ViewFileList; + FileMonitor fileMonitor; +}; + +} // Histo +#endif //BOARD_SCALARVISUALLYSERVER_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/CMakeLists.txt b/plugins/mindstudio-insight-plugins/Scalar/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a9711ef2642dd29c2014a8b4a14f67ddd3330946 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/CMakeLists.txt @@ -0,0 +1,2 @@ +set(SCALAR_PROJECT_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +add_subdirectory(server) \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/.gitignore b/plugins/mindstudio-insight-plugins/Scalar/server/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..43bdbcd1e07f8044d4a47a23268a9743a4b4cb8f --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/.gitignore @@ -0,0 +1,4 @@ +cmake-build-* +output +*.pb.cc +*.pb.h \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/CMakeLists.txt b/plugins/mindstudio-insight-plugins/Scalar/server/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d2c9f22cfad9cf5b0d09bd538ed0a25c3c14e96d --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.20) +project(Scalar) + +set(CMAKE_BUILD_TYPE Debug CACHE STRING "Build type" FORCE) + +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0) + message(FATAL_ERROR "GCC version must be 7.3.0 and above, but found ${CMAKE_CXX_COMPILER_VERSION}") + elseif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 11.4.0) + message(WARNING "GCC version ${CMAKE_CXX_COMPILER_VERSION} is greater than 11.4.0, may cause unknown problems.") + endif() +endif() + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_C_STANDARD 11) + +set(HOME_DIR ${PROJECT_SOURCE_DIR}) +set(EXECUTABLE_OUTPUT_PATH ${PROJECT_ROOT_DIR}/output/plugins) +set(LIBRARY_OUTPUT_PATH ${PROJECT_ROOT_DIR}/output/plugins) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector-all -D_FORTIFY_SOURCE=2 -O0 -g -ftrapv -fstack-protector-strong -fPIE -fPIC") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstack-protector-all -D_FORTIFY_SOURCE=2 -O0 -g -ftrapv -fstack-protector-strong -fPIE -fPIC") +if (${CMAKE_BUILD_TYPE} MATCHES "Debug") + message(STATUS "Enable debug symbol table, change optimization level to 0") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -g -O0") +endif () + +set(CMAKE_SKIP_RPATH true) +if (CMAKE_SYSTEM_NAME MATCHES "windows") + if ((NOT CMAKE_BUILD_TYPE MATCHES "Debug") AND (NOT CMAKE_BUILD_TYPE MATCHES "PROFILE")) + message(STATUS "Build type = ${CMAKE_BUILD_TYPE}, static = enable.") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--nxcompat -Wl,--dynamicbase -s -pie -Wincompatible-pointer-types") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wl,--nxcompat -Wl,--dynamicbase -s -pie -Wincompatible-pointer-types") + add_link_options(-static) + endif() +elseif() + if ((NOT CMAKE_BUILD_TYPE MATCHES "Debug") AND (NOT CMAKE_BUILD_TYPE MATCHES "PROFILE")) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -s -pie -Wl,-z,now") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -s -pie -Wl,-z,now") + endif() +endif() + + +add_subdirectory(src) \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/CMakeLists.txt b/plugins/mindstudio-insight-plugins/Scalar/server/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e87c7fa7a6cf1e4dd33bd2f25136628ad6e49716 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/CMakeLists.txt @@ -0,0 +1,70 @@ +set(SRC_HOME_DIR ${HOME_DIR}/src) +aux_source_directory(plugin SCALAR_SRC_LIST) +aux_source_directory(plugin/Handler SCALAR_SRC_LIST) +aux_source_directory(parser SCALAR_SRC_LIST) +aux_source_directory(Sampler SCALAR_SRC_LIST) +aux_source_directory(${PROJECT_ROOT_DIR}/proto SCALAR_SRC_LIST) +aux_source_directory(defs SCALAR_SRC_LIST) +aux_source_directory(FileManager SCALAR_SRC_LIST) +aux_source_directory(GraphManager SCALAR_SRC_LIST) +aux_source_directory(Util SCALAR_SRC_LIST) +set(LOG_SRC ${PROJECT_ROOT_DIR}/plugin_core/src/Logger.cpp) + +list(APPEND ${PROJECT_NAME}_SRC + ${PROTO_SRC} + ${SCALAR_SRC_LIST} + ${LOG_SRC}) +include_directories(${SRC_HOME_DIR} + ${SRC_HOME_DIR}/plugin + ${SRC_HOME_DIR}/parser + ${SRC_HOME_DIR}/Sampler + ${PROJECT_ROOT_DIR}/proto + ${PROJECT_ROOT_DIR} +) +set(LIBRARY_OUTPUT_PATH ${LIBRARY_OUTPUT_PATH}/${PROJECT_NAME}) +add_library(${PROJECT_NAME} SHARED ${${PROJECT_NAME}_SRC}) +target_include_directories(${PROJECT_NAME} PRIVATE ${${PROJECT_NAME}_H}) +target_include_directories(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/rapidjson/include/rapidjson) +if (${CMAKE_BUILD_TYPE} MATCHES "Debug") + message(STATUS "Open Debug Info") + target_compile_options(${PROJECT_NAME} PRIVATE -O0 -g) +endif () +target_include_directories(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/plugin_core/include) +target_link_libraries(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/output/lib/libmsinsight.so) +target_link_libraries(${PROJECT_NAME} PRIVATE mindboard::protobuf) +set(CMAKE_CXX_VISIBILITY_PRESET default) +if (${CMAKE_SYSTEM_NAME} MATCHES "Linux") + target_link_libraries(${PROJECT_NAME} PRIVATE stdc++fs) +endif () + +#-------- test -------- + +if (ENABLE_TESTCASES OR ENABLE_CPP_ST) + enable_testing() + aux_source_directory(test TEST_SRC) + aux_source_directory(${PROJECT_ROOT_DIR}/plugin_core/src TEST_SRC) + list(APPEND ${PROJECT_NAME}_TEST_SRC + ${${PROJECT_NAME}_SRC} + ${TEST_SRC}) + list(APPEND ${PROJECT_NAME}_TEST_H + ${${PROJECT_NAME}_H}) + add_executable(${PROJECT_NAME}_test ${${PROJECT_NAME}_TEST_SRC}) + message(STATUS "Dir:${PROJECT_ROOT_DIR}") + target_include_directories(${PROJECT_NAME}_test PUBLIC ${${PROJECT_NAME}_TEST_H}) + target_include_directories(${PROJECT_NAME}_test PUBLIC ${PROJECT_ROOT_DIR}/plugin_core/include) + target_include_directories(${PROJECT_NAME}_test PUBLIC ${PROJECT_ROOT_DIR}) + target_include_directories(${PROJECT_NAME}_test PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/plugin) + # if (CMAKE_SYSTEM_NAME STREQUAL "Linux") + message(STATUS "Open Debug ${CMAKE_BUILD_TYPE}") + target_link_libraries(${PROJECT_NAME}_test PRIVATE stdc++fs) + # endif () + target_link_libraries(${PROJECT_NAME}_test PRIVATE mindboard::protobuf) + target_link_libraries(${PROJECT_NAME}_test PUBLIC gtest_main gmock_main) + target_link_libraries(${PROJECT_NAME}_test PUBLIC ${PROJECT_ROOT_DIR}/output/lib/libmsinsight.so) + target_link_libraries(${PROJECT_NAME}_test PRIVATE mindboard::protobuf) + if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") + target_compile_options(${PROJECT_NAME}_test PRIVATE -O0 -g) + endif () + + set_target_properties(${PROJECT_NAME}_test PROPERTIES EXCLUDE_FROM_ALL true) +endif () \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileInfoManager.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileInfoManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2a6829045d5007fb81b4334ce048dc07b09b89f7 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileInfoManager.cpp @@ -0,0 +1,53 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include "FileInfoManager.h" +#include "Logger.h" + +using namespace Insight::Scalar::FileInfo; +using namespace Insight; + +std::shared_ptr FileInfoManager::AddFile(const std::string &filePath, ParseDataType dataType) { + if (filePaths_.count(filePath) != 0) { + return fileInfoMap_[filePath]; + } + auto fileInfo = std::make_shared(filePath, dataType); + filePaths_.insert(filePath); + fileInfoMap_[filePath] = fileInfo; + return fileInfo; +} + +bool FileInfoManager::DelFile(std::string_view filePath) { + if (filePaths_.count(filePath.data()) == 0) { + return false; + } + filePaths_.erase(filePath.data()); + fileInfoMap_.erase(filePath.data()); + return true; +} + +std::shared_ptr FileInfoManager::GetFileInfo(std::string_view filePath) { + if (filePaths_.count(filePath.data()) == 0) { + return nullptr; + } + return fileInfoMap_[filePath.data()]; +} + +void FileInfoManager::Reset() { + filePaths_.clear(); + fileInfoMap_.clear(); +} + +void FileInfoManager::OnFileCreate(std::string &&dir, std::string &&fileName) { + LOG(LogRank::Error) << "Watched new file under " << dir << ", file:" << fileName; + createFileGroupByDir_[std::move(dir)].emplace(std::move(fileName)); +} + +std::unordered_map> FileInfoManager::GetCreatedFileGroupByDir() { + auto res = std::move(createFileGroupByDir_); + createFileGroupByDir_ = std::unordered_map>(); + return res; +} + + + diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileInfoManager.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileInfoManager.h new file mode 100644 index 0000000000000000000000000000000000000000..a851bd6b647ad5abc43b0fb53a7ebd4c17809c86 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileInfoManager.h @@ -0,0 +1,93 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef FILEINFOMANAGER_H +#define FILEINFOMANAGER_H + +#include +#include +#include +#include +#include +#include "ParserFactory.h" + +namespace Insight::Scalar::FileInfo { +enum class UpdateType { + ADD, + DEL +}; + +struct FileInfo { + FileInfo(std::string filePath, ParseDataType dataType) : filePath_(std::move(filePath)), + parseDataType_(dataType) {}; + std::string filePath_; + ParseDataType parseDataType_; + uint64_t offSet_{0}; + bool empty_{false}; + bool imported_{true}; +}; + +class FileInfoManager { +public: + /** + * @brief add file object, return the ptr of new object, if already exist, return the ptr old object + * @param filePath + * @param dataType + * @return + */ + std::shared_ptr AddFile(const std::string &filePath, ParseDataType dataType); + + /** + * @brief delete file object, if success return true + * @param filePath + * @return + */ + bool DelFile(std::string_view filePath); + + std::shared_ptr GetFileInfo(std::string_view filePath); + + void Reset(); + + void OnFileCreate(std::string &&dir, std::string &&fileName); + + std::unordered_map> GetCreatedFileGroupByDir(); + + static inline ParseDataType GetFileType(std::string_view filePath) { + for (const auto &[k, v]: fileTypeMap_) { + std::regex regex(k); + std::smatch match; + std::string fileName = fs::path(filePath).filename().string(); + if (std::regex_search(fileName, match, regex)) { + return v; + } + } + return ParseDataType::Unknown; + } + + static inline bool IsFileSupported(std::string_view path) { + return std::any_of(fileTypeMap_.begin(), fileTypeMap_.end(), [&path](auto fileTypeIter) { + std::string mathRegex = fileTypeIter.first; + std::regex regex(mathRegex.data()); + std::smatch match; + std::string fileName = fs::path(path).filename().string(); + if (!std::regex_search(fileName, match, regex)) { + return false; + } + return true; + }); + } + +private: + std::set filePaths_; + std::unordered_map> fileInfoMap_; + std::unordered_map> createFileGroupByDir_; + + inline static std::map fileTypeMap_ = { + {R"(out.tfevent)", Insight::Scalar::ParseDataType::TF_EVENT}, + {R"(out.events.summary)", Insight::Scalar::ParseDataType::MindSpore_Summary}, + {R"(worker_[0-9]+\.log)", Insight::Scalar::ParseDataType::TEXT_LOG} + }; +}; +} + +#endif //FILEINFOMANAGER_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcher.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcher.h new file mode 100644 index 0000000000000000000000000000000000000000..677893ce7a9feee70ddf72cdedcd78a748cfad4c --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcher.h @@ -0,0 +1,34 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef FILEWATCH_H +#define FILEWATCH_H + +#include +#include +#include +#include +#include +#include "defs/ConceptDefs.h" + +namespace Insight::Scalar::FileWatch { +class FileWatcher { +public: + FileWatcher() = default; + + virtual void Init() = 0; + + virtual void AddWatchPath(const std::vector &watchFileList) {}; + + virtual void DelWatchPath(const std::vector &DelFileList) {}; + + virtual void OnFileWriteClose(std::string &&dir, std::string &&fileName) {}; + + virtual void OnFileCreated(std::string &&dir, std::string &&fileName) {}; + + virtual void Reset() = 0; + + virtual ~FileWatcher() = default; +}; +} +#endif //FILEWATCH_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcherFactory.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcherFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..5cccb49c6266cdb8ab4511d91cee624684e94a3a --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcherFactory.h @@ -0,0 +1,33 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef BOARD_PLUGINS_SCALARVISUALLY_SRC_FILEMANAGER_FILEWATCHERFACTORY_H_ +#define BOARD_PLUGINS_SCALARVISUALLY_SRC_FILEMANAGER_FILEWATCHERFACTORY_H_ + +#include +#include +#include +#include +#include +#include "defs/ConceptDefs.h" +#include "FileWatcherLinuxImpl.h" + +namespace Insight::Scalar::FileWatch { + +class FileWatcherFactory { +public: + static inline std::unique_ptr GetFileWatcher() { +#ifdef __linux__ + return std::make_unique(); +#elif defined(_WIN32) + // TODO windows platform impelement + return nullptr; +#elif defined(__APPLE__) + // TODO mac platform impelemnt + return nullptr; +#endif + return nullptr; + } +}; +} +#endif //BOARD_PLUGINS_SCALARVISUALLY_SRC_FILEMANAGER_FILEWATCHERFACTORY_H_ diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcherLinuxImpl.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcherLinuxImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..31a36a4f134260407b60e6521ddc18f136b59aad --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcherLinuxImpl.cpp @@ -0,0 +1,184 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include "FileWatcherLinuxImpl.h" +#ifdef __linux__ +#include +#include +#include +#include +#include +#include "Logger.h" +#include "plugin/ScalarVisuallyServer.h" + +using namespace Insight::Scalar::FileWatch; +using namespace Insight; +namespace fs = std::experimental::filesystem; +constexpr size_t BUFFER_SIZE = sizeof(struct inotify_event) + NAME_MAX + 1; + +void FileWatcherLinuxImpl::Init() { + inotifyFd_ = inotify_init1(IN_NONBLOCK | IN_CLOEXEC); + if (inotifyFd_ == -1) { + LOG(LogRank::Error) << "Init file watch, error:" << strerror(errno); + return; + } + watchThread_ = std::make_unique(WatchFunc, std::ref(*this)); + watchThread_->detach(); +} + +void FileWatcherLinuxImpl::AddWatchPath(const std::vector &watchFileList) { + if (watchFileList.empty()) { + return; + } + for (const auto &file: watchFileList) { + fs::path path(file); + std::string dir = path.parent_path(); + std::string fileName = path.filename(); + if (IsFileWatched(dir, fileName)) { + LOG(LogRank::Info) << "Already watched, not add repeatedly"; + continue; + } + int fd = inotify_add_watch(inotifyFd_, dir.c_str(), watchFlag_); + if (fd == -1) { + LOG(LogRank::Error) << "Add File Watch failed, error=" << strerror(errno) << ", path=" << file; + continue; + } + std::unique_lock lock(mutex_); + watchDirFd_.emplace(dir, fd); + watchFdDir_.emplace(fd, dir); + watchDirToFiles_[dir].insert(fileName); + } +} + +void FileWatcherLinuxImpl::WaitStopWatchThread(unsigned int millSeconds) { + { + std::lock_guard lock(mutex_); + stop_ = true; + } + std::unique_lock lock(exitMutex_); + exitCv.wait_for(lock, std::chrono::milliseconds(millSeconds)); + close(inotifyFd_); +} + +void FileWatcherLinuxImpl::DelWatchPath(const std::vector &DelFileList) { + if (DelFileList.empty()) { + LOG(LogRank::Error) << "Del watch failed, path empty"; + return; + } + for (const auto &file: DelFileList) { + fs::path path(file); + std::string dir = path.root_directory(); + std::string fileName = path.filename(); + if (!IsFileWatched(dir, fileName)) { + LOG(LogRank::Info) << "Never watch this file"; + continue; + } + std::lock_guard lock(mutex_); + int fd = watchDirFd_[dir]; + if (inotify_rm_watch(inotifyFd_, fd) == -1) { + LOG(LogRank::Error) << "Delete failed, error=" << strerror(errno); + continue; + } + watchDirToFiles_[dir].erase(fileName); + if (watchDirToFiles_[dir].empty()) { + watchDirFd_.erase(dir); + watchFdDir_.erase(fd); + watchDirToFiles_.erase(dir); + } + } +} + +void FileWatcherLinuxImpl::NotifyWatchThreadStopped() { + + std::unique_lock lock(exitMutex_); + exitCv.notify_all(); +} + +void FileWatcherLinuxImpl::WatchFunc(Insight::Scalar::FileWatch::FileWatcherLinuxImpl &watcher) { + char eventBuff[BUFFER_SIZE] = {0}; + struct inotify_event *event = nullptr; + while (!watcher.stop_) { + // read the inotify message + memset(eventBuff, 0, sizeof(eventBuff)); + ssize_t result = read(watcher.inotifyFd_, eventBuff, sizeof(eventBuff)); + if (result <= 0) { + continue; + } + event = reinterpret_cast(eventBuff); + if (!watcher.IsDirWatched(event->wd)) { + continue; + } + std::string dir = watcher.GetWatchedDirName(event->wd); + std::string fileName = std::string(event->name); + if ((event->mask & IN_MODIFY) || (event->mask & IN_CLOSE_WRITE)) { + watcher.OnFileWriteClose(std::move(dir), std::move(fileName)); + } else if ((event->mask & IN_CREATE) || (event->mask & IN_MOVED_TO)) { + watcher.OnFileCreated(std::move(dir), std::move(fileName)); + } + } + LOG(LogRank::Info) << "Exit watch thread"; + watcher.NotifyWatchThreadStopped(); +} + +void FileWatcherLinuxImpl::OnFileCreated(std::string &&dir, std::string &&fileName) { + auto &server = ScalarVisuallyServer::Instance(); + LOG(LogRank::Info) << "Find new file in dir:" << dir << ", filename:" << fileName; + auto fileInfo = server.AddFile(dir + "/" + fileName); + fileInfo->empty_ = true; + fileInfo->imported_ = false; + server.OnFileCreate(std::move(dir), std::move(fileName)); +} + +void FileWatcherLinuxImpl::OnFileWriteClose(std::string &&dir, std::string &&fileName) { + auto &server = ScalarVisuallyServer::Instance(); + if (!server.IsFileImported(dir + "/" + fileName)) { + return; + } + server.OnFileDataUpdate(std::move(dir), std::move(fileName)); +} + +void FileWatcherLinuxImpl::Reset() { + std::lock_guard lock(mutex_); + for (auto &[path, fd]: watchDirFd_) { + inotify_rm_watch(inotifyFd_, fd); + } + watchDirFd_.clear(); + watchFdDir_.clear(); +} + +bool FileWatcherLinuxImpl::IsDirWatched(const std::string &dir) { + std::shared_lock lock(mutex_); + return watchDirFd_.count(dir) != 0; +} + +bool FileWatcherLinuxImpl::IsFileWatched(const std::string &dir, const std::string &file) { + if (!IsDirWatched(dir)) { + return false; + } + std::shared_lock lock(mutex_); + return watchDirToFiles_[dir].count(file) != 0; +} + +bool FileWatcherLinuxImpl::IsDirWatched(int wd) { + if (wd == -1) { + return false; + } + std::shared_lock lock(mutex_); + if (watchFdDir_.count(wd) == 0) { + return false; + } + return true; +} + +std::string FileWatcherLinuxImpl::GetWatchedDirName(int wd) { + if (wd == -1) { + return ""; + } + std::shared_lock lock(mutex_); + if (watchFdDir_.count(wd) == 0) { + return ""; + } + return watchFdDir_.at(wd); +} + +#endif diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcherLinuxImpl.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcherLinuxImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..d7e90effb94cc5ca8a4a15ac1b153cc49b67d5d5 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/FileManager/FileWatcherLinuxImpl.h @@ -0,0 +1,50 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef FILEWATCHLINUXIMPL_CPP_H +#define FILEWATCHLINUXIMPL_CPP_H + +#include "FileWatcher.h" +#ifdef __linux__ +#include +#include +#include +#include +#include +#include + +namespace Insight::Scalar::FileWatch { +class FileWatcherLinuxImpl : public FileWatcher { +public: + FileWatcherLinuxImpl() = default; + + void Init() override; + void AddWatchPath(const std::vector &watchFileList) override; + void WaitStopWatchThread(unsigned int millSeconds); + static void WatchFunc(FileWatcherLinuxImpl &watcher); + void DelWatchPath(const std::vector &DelFileList) override; + void OnFileCreated(std::string &&dir, std::string &&fileName) override; + void OnFileWriteClose(std::string &&dir, std::string &&fileName) override; + void Reset() override; + bool IsDirWatched(const std::string &dir); + bool IsDirWatched(int wd); + bool IsFileWatched(const std::string &dir, const std::string &file); + std::string GetWatchedDirName(int wd); + void NotifyWatchThreadStopped(); + ~FileWatcherLinuxImpl() override = default; + + std::unique_ptr watchThread_; + int inotifyFd_ = -1; + std::unordered_map watchDirFd_; + std::unordered_map watchFdDir_; + std::unordered_map> watchDirToFiles_; + std::shared_mutex mutex_; + std::mutex exitMutex_; + std::condition_variable exitCv; + bool stop_ = false; + inline static int watchFlag_ = IN_CLOSE_WRITE | IN_MODIFY | IN_MOVED_TO | IN_CREATE; +}; +} +#endif // __linux__ + +#endif //FILEWATCHLINUXIMPL_CPP_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/GraphManager/GraphManager.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/GraphManager/GraphManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f36990c7a7912a149960b5112fdf2227a38310b --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/GraphManager/GraphManager.cpp @@ -0,0 +1,136 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include +#include +#include "Logger.h" +#include "SamplerFactory.h" +#include "Util/ScalaryProtocolUtil.h" +#include "GraphManager.h" + +using namespace Insight::Scalar::GraphOp; +using namespace Insight; +using namespace Protocol; + +std::optional +Graph::GetFileData(const std::string &file, uint64_t offset, uint64_t sampleOffset, + std::unique_ptr sampleParam) { + GraphData data; + data.tag_ = tag_; + data.filePath_ = file; + if (!InnerFile(file)) { + return data; + } + auto &dataSrc = dataMap_[file]; + offset = offset > dataSrc.size() ? dataSrc.size() : offset; + std::copy(dataSrc.begin() + static_cast(offset), dataSrc.end(), std::back_inserter(data.graphData_)); + if (!sampleParam->algorithm_.empty()) { + UpdateSample(file, std::move(sampleParam)); + GetSampledData(file, sampleOffset, data); + } + return data; +} + +void Graph::UpdateData(const std::string &file, std::vector &&data) { + if (!InnerFile(file)) { + dataFiles_.insert(file); + dataMap_[file].clear(); + } + auto &dst = dataMap_[file]; + std::move(data.begin(), data.end(), std::back_inserter(dst)); +} + +bool Graph::InnerFile(const std::string &file) { + return dataFiles_.count(file) != 0; +} + +std::vector Graph::GetDataFiles() { + std::vector res; + std::copy(dataFiles_.begin(), dataFiles_.end(), std::back_inserter(res)); + return res; +} + +void Graph::UpdateSample(const std::string &file, std::unique_ptr sampleParam) { + if (sampleParam == nullptr || sampleParam->algorithm_.empty()) { + return; + } + // init or different sampler + if (samplers_.find(file) == samplers_.end() || samplers_.at(file)->GetAlgorithm() != sampleParam->algorithm_) { + auto sampler = SampleFactory::Instance().GetSampler(sampleParam->algorithm_); + sampler->SetSampleParam(std::move(sampleParam)); + samplers_[file] = std::move(sampler); + return; + } + // replace, here no data race + if (!samplers_.at(file)->IsSameParam(sampleParam.get())) { + samplers_.at(file)->SetSampleParam(std::move(sampleParam)); + } +} + +void Graph::GetSampledData(const std::string &file, uint64_t sampleOffset, GraphData &data) { + auto &dataSrc = dataMap_[file]; + auto samplerPtr = GetSampler(file); + if (samplerPtr == nullptr) { + return; + } + data.graphSampledData_ = std::move(samplerPtr->Sample(dataSrc, sampleOffset, -1)); +} + +Sample::SamplerBase *Graph::GetSampler(const std::string &file) { + if (samplers_.find(file) != samplers_.end()) { + return samplers_[file].get(); + } + return nullptr; +} + +std::optional GraphManager::GetGraphData(const SingleGraphReqInfo &reqInfo) { + if (!GraphExits(reqInfo.tag_)) { + LOG(LogRank::Error) << "Get data failed, no such graph, tag:" << reqInfo.tag_; + return std::nullopt; + } + auto graph = GetGraph(reqInfo.tag_); + return graph->GetFileData(reqInfo.file_, reqInfo.offset_, reqInfo.sampleOffset_, + std::make_unique(reqInfo.sampleAlgorithm_, reqInfo.sampleWeight_)); +} + +void GraphManager::UpdateGraphData(const std::string &tag, const std::string &file, std::vector &&data) { + if (!GraphExits(tag)) { + auto graph = std::make_shared(tag); + graphs_.emplace(tag, graph); + } + std::shared_ptr graph = GetGraph(tag); + if (graph) { + return graph->UpdateData(file, std::move(data)); + } +} + +std::shared_ptr GraphManager::GetGraph(const std::string &tag) { + if (!GraphExits(tag)) { + return nullptr; + } + return graphs_.at(tag); +} + +bool GraphManager::GraphExits(const std::string &tag) { + return graphs_.find(tag) != graphs_.end(); +} + +void GraphManager::Reset() { + graphs_.clear(); +} + +std::unordered_map> GraphManager::GetAllGraphInfo() { + std::unordered_map> res; + for (const auto &[tag, graph]: graphs_) { + res[tag] = graph->GetDataFiles(); + } + return res; +} + +void GraphManager::GetFileTags(std::string &path, std::set &tags) { + for (auto &[tag, graph]: graphs_) { + if (graph->InnerFile(path)) { + tags.insert(tag); + } + } +} diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/GraphManager/GraphManager.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/GraphManager/GraphManager.h new file mode 100644 index 0000000000000000000000000000000000000000..70a7766f6d129435da9268504993578f806dcb5c --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/GraphManager/GraphManager.h @@ -0,0 +1,76 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef GRAPHMANAGER_H +#define GRAPHMANAGER_H + +#include +#include +#include +#include +#include +#include +#include +#include "defs/ConceptDefs.h" +#include "SamplerBase.h" +#include "Util/ScalaryProtocolUtil.h" + +namespace Insight::Scalar::GraphOp { +using namespace Protocol; +struct GraphData { + std::string tag_; + std::string filePath_; + std::vector graphData_; + std::vector graphSampledData_; +}; + +class Graph { +public: + Graph() = default; + /** + * @brief constructor of Graph, the graphId and tag is need + * @param graphId + * @param tag + */ + explicit Graph(std::string tag) : tag_(std::move(tag)) { + }; + void UpdateData(const std::string &file, std::vector &&data); + std::optional + GetFileData(const std::string &file, uint64_t offset, uint64_t sampleOffset, + std::unique_ptr sampleParam); + std::vector GetDataFiles(); + bool InnerFile(const std::string &file); + +private: + void UpdateSample(const std::string &file, std::unique_ptr sampleParam); + void GetSampledData(const std::string &file, uint64_t sampleOffset, GraphData &data); + Sample::SamplerBase *GetSampler(const std::string &file); + + std::string tag_; + std::set dataFiles_; + std::unordered_map> dataMap_; + std::unordered_map> samplers_; +}; + +class GraphManager { +public: + /** + * @brief + * @param tag + * @param file + * @param offset + * @return + */ + std::optional GetGraphData(const SingleGraphReqInfo &reqInfo); + void UpdateGraphData(const std::string &tag, const std::string &file, std::vector &&data); + std::shared_ptr GetGraph(const std::string &tag); + void Reset(); + std::unordered_map> GetAllGraphInfo(); + void GetFileTags(std::string &path, std::set &tags); +private: + bool GraphExits(const std::string &tag); + + std::unordered_map> graphs_; +}; +} +#endif //GRAPHMANAGER_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/IRLowPassSampler.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/IRLowPassSampler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..33bfcfa9d29c74963a783a92ecfa3d5f13a08dd6 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/IRLowPassSampler.cpp @@ -0,0 +1,75 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#include "IRLowPassSampler.h" +#include +#include +#include + +using namespace Insight::Scalar::Sample; + +void IRLowPassSampler::SetSampleParam(std::unique_ptr param) { + if (param == nullptr) { + return; + } + if (param->weight != weight_) { + last_ = 0.0; + numAccum_ = 0; + weight_ = param->weight; + } +} + +std::vector +IRLowPassSampler::Sample(const std::vector &original, uint32_t start, int32_t length) { + if (start > original.size() || weight_ == 0.0) { + return {}; + } + if (start == 0) // frontend reopen or fresh, need clean cache + { + last_ = 0.0; + numAccum_ = 0; + } + size_t dataLen = original.size() - start; + if (length != -1) // -1 + { + dataLen = length > dataLen ? dataLen : length; + } + std::vector result; + result.reserve(dataLen); + float firstValue = original[0].value_; + bool isConstant = std::all_of(original.begin(), original.end(), [&firstValue](const auto &point) { + return point.value_ == firstValue; + }); + for (size_t index = 0; index < dataLen; index++) { + auto &point = original[index]; + ScalarPoint sampledPoint{}; + sampledPoint.step_ = point.step_; + if (isConstant || std::isinf(point.value_) || std::isnan(point.value_ + )) { + sampledPoint.value_ = point.value_; + } else { + last_ = last_ * weight_ + (1 - weight_) * point.value_; + numAccum_++; + float debiasWeight = 1.0; + if (weight_ != 1.0) { + debiasWeight = debiasWeight - static_cast(pow(weight_, numAccum_)); + } + sampledPoint.value_ = last_ / debiasWeight; + } + result.emplace_back(sampledPoint); + } + return result; +} + +bool IRLowPassSampler::IsSameParam(SampleParam *param) { + if (param == nullptr) { + return false; + } + if (param->algorithm_ != algorithm_) { + return false; + } + if (param->weight != weight_) { + return false; + } + return true; +} diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/IRLowPassSampler.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/IRLowPassSampler.h new file mode 100644 index 0000000000000000000000000000000000000000..c6a9dea6f51e39b785f556b0dac03cb50e7213da --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/IRLowPassSampler.h @@ -0,0 +1,30 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef MINDSTUDIO_BOARD_PLUGINS_MINDSTUDIO_INSIGHT_PLUGINS_SCALAR_SERVER_SRC_SAMPLER_IRLOWPASSSAMPLER_H_ +#define MINDSTUDIO_BOARD_PLUGINS_MINDSTUDIO_INSIGHT_PLUGINS_SCALAR_SERVER_SRC_SAMPLER_IRLOWPASSSAMPLER_H_ + +#include "SamplerBase.h" +#include + +using namespace Insight::Scalar::Sample; + +class IRLowPassSampler : public SamplerBase { +public: + IRLowPassSampler() { + algorithm_ = "smoothing"; + }; + + void SetSampleParam(std::unique_ptr param) override; + + std::vector Sample(const std::vector &original, uint32_t start, int32_t length) override; + + bool IsSameParam(SampleParam *param) override; + +private: + float weight_{0.0}; + float last_{0.0}; + uint32_t numAccum_{0}; +}; + +#endif //MINDSTUDIO_BOARD_PLUGINS_MINDSTUDIO_INSIGHT_PLUGINS_SCALAR_SERVER_SRC_SAMPLER_IRLOWPASSSAMPLER_H_ diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/SamplerBase.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/SamplerBase.h new file mode 100644 index 0000000000000000000000000000000000000000..2b1146d905f2ce71ce08233f2fc8ec32be4ec1ee --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/SamplerBase.h @@ -0,0 +1,43 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef MINDSTUDIO_BOARD_PLUGINS_MINDSTUDIO_INSIGHT_PLUGINS_SCALAR_SERVER_SRC_SAMPLER_SAMPLERBASE_H_ +#define MINDSTUDIO_BOARD_PLUGINS_MINDSTUDIO_INSIGHT_PLUGINS_SCALAR_SERVER_SRC_SAMPLER_SAMPLERBASE_H_ + +#include +#include +#include "defs/ConceptDefs.h" + +using namespace Insight::Scalar; +namespace Insight::Scalar::Sample { +struct SampleParam { + SampleParam() = default; + + SampleParam(std::string algorithm, float weight) : algorithm_(std::move(algorithm)), weight(weight) {} + + std::string algorithm_; + float weight{0}; +}; + +/** + * @brief use for data sampler + */ +class SamplerBase { +public: + SamplerBase() = default; + + virtual std::string GetAlgorithm() { return algorithm_; } + + virtual ~SamplerBase() = default; + + virtual void SetSampleParam(std::unique_ptr param) = 0; + + virtual std::vector Sample(const std::vector &original, uint32_t start, int32_t end) = 0; + + virtual bool IsSameParam(SampleParam *param) = 0; + +protected: + std::string algorithm_; +}; +} +#endif //MINDSTUDIO_BOARD_PLUGINS_MINDSTUDIO_INSIGHT_PLUGINS_SCALAR_SERVER_SRC_SAMPLER_SAMPLERBASE_H_ diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/SamplerFactory.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/SamplerFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..083af99835bec765c7ecc0c9a5f4b3a3d37c0928 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/Sampler/SamplerFactory.h @@ -0,0 +1,43 @@ +/* + * Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef MINDSTUDIO_BOARD_SAMPLERFACTORY_H +#define MINDSTUDIO_BOARD_SAMPLERFACTORY_H + +#include "SamplerBase.h" +#include "IRLowPassSampler.h" +#include +#include +#include +#include +#include + +namespace Insight::Scalar::Sample { +class SampleFactory { +public: + static SampleFactory &Instance() { + static SampleFactory instance; + return instance; + } + + SampleFactory() { + sampleMap_.emplace("smoothing", []() { + return std::make_unique(); + }); + } + + std::unique_ptr GetSampler(std::string_view algorithm) { + auto it = sampleMap_.find(algorithm); + if (it == sampleMap_.end()) { + return nullptr; + } + auto func = it->second; + return func(); + } + +private: + std::unordered_map()>> sampleMap_{}; +}; + +} +#endif //MINDSTUDIO_BOARD_SAMPLERFACTORY_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/Util/FileUtil.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/Util/FileUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..ad56ba1446298051d4af6541948c5693c03c9039 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/Util/FileUtil.h @@ -0,0 +1,149 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_PLUGINS_SCALARVISUALLY_SRC_PROTOCOLUTIL_FILEUTIL_H_ +#define BOARD_PLUGINS_SCALARVISUALLY_SRC_PROTOCOLUTIL_FILEUTIL_H_ + +#include +#include +#include +#include +#include +#include +#ifdef _WIN32 +#include +#include +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif + +namespace Insight::Scalar { +class StrUtil { +public: + static inline uint8_t ByteNum(uint8_t byte) { + const static std::map MAP_BYTES = { + {0xE0, 0xC0}, + {0xF0, 0xE0}, + {0xF8, 0xF0}, + {0xFC, 0xF8}, + {0xFE, 0xFC} + }; + uint8_t index = 0; + uint8_t byteNum = 0; + for (auto [k, v]: MAP_BYTES) { + index++; + if ((byte & k) == v) { + byteNum = index; + break; + } + } + return byteNum; + } + + static inline bool IsUtf8String(const std::string &str) { + uint32_t byteNum = 0; + for (const char c: str) { + const auto byte = static_cast(c); + if ((byteNum == 0) && ((byte & 0x80) == 0)) { + continue; + } + if (byteNum == 0) { + byteNum = ByteNum(byte); + if (byteNum == 0) { + return false; + } + } else { + if ((byte & 0xC0) != 0x80) { + return false; + } + byteNum--; + } + } + return true; + } + +#ifdef _WIN32 + static inline std::string Utf8ToGbk(const char *src) + { + if (src == nullptr) { + return ""; + } + int len = MultiByteToWideChar(CP_UTF8, 0, src, -1, nullptr, 0); + const auto wstr = std::make_unique(len + 1); + MultiByteToWideChar(CP_UTF8, 0, src, -1, wstr.get(), len); + len = WideCharToMultiByte(CP_ACP, 0, wstr.get(), -1, nullptr, 0, nullptr, nullptr); + const auto str = std::make_unique(len + 1); + WideCharToMultiByte(CP_ACP, 0, wstr.get(), -1, str.get(), len, nullptr, nullptr); + return {str.get()}; + } + + static inline std::string GbkToUtf8(const char *src) + { + if (src == nullptr) { + return ""; + } + int len = MultiByteToWideChar(CP_ACP, 0, src, -1, nullptr, 0); + const auto wstr = std::make_unique(len + 1); + MultiByteToWideChar(CP_ACP, 0, src, -1, wstr.get(), len); + len = WideCharToMultiByte(CP_UTF8, 0, wstr.get(), -1, nullptr, 0, nullptr, nullptr); + const auto str = std::make_unique(len + 1); + WideCharToMultiByte(CP_UTF8, 0, wstr.get(), -1, str.get(), len, nullptr, nullptr); + return {str.get()}; + } +#endif +}; + +class FileUtil { +public: + static inline std::string PathPreProcess(std::string path) { +#ifdef WIN32 + if (StrUtil::IsUtf8String(path)) { + path = StrUtil::Utf8ToGbk(path.c_str()); + } +#endif + return path; + } + + static inline bool FindFolder(const std::string_view path, + std::vector &folders, + std::vector &files) { + if (path.empty()) { + return false; + } + if (!fs::exists(path) || !fs::is_directory(path)) { + return false; + } + for (auto &entry: fs::directory_iterator(path)) { + if (std::string name = entry.path().filename().string(); name == "." || name == "..") { + continue; + } + if (fs::is_directory(entry)) { + folders.emplace_back(entry.path().string()); + } else if (fs::is_regular_file(entry)) { + files.emplace_back(entry.path().string()); + } + } + return true; + } + + static inline void ScanFolderIf(std::string_view path, + std::vector &dst, + const std::function &func) { + if (path.empty() || !fs::exists(path)) { + return; + } + if (!fs::is_directory(path)) { + return; + } + for (const auto &entry: fs::directory_iterator(path)) { + if (func(entry)) { + dst.emplace_back(path); + } + } + } +}; +} +#endif //BOARD_PLUGINS_SCALARVISUALLY_SRC_PROTOCOLUTIL_FILEUTIL_H_ diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/Util/ScalaryProtocolUtil.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/Util/ScalaryProtocolUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..eb1aebcde28088aedd5ea2b723d44f7d1e0b75b2 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/Util/ScalaryProtocolUtil.h @@ -0,0 +1,168 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_PLUGINS_SCALARVISUALLY_SRC_PROTOCOLUTIL_PROTOCOLUTIL_H_ +#define BOARD_PLUGINS_SCALARVISUALLY_SRC_PROTOCOLUTIL_PROTOCOLUTIL_H_ + +#include "document.h" +#include "writer.h" + +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else + +#include + +namespace fs = std::experimental::filesystem; +#endif + +#include +#include +#include +#include +#include + +namespace Insight::Scalar::Protocol { +struct SingleGraphReqInfo { + std::string tag_; + std::string file_; + uint64_t offset_; + std::string sampleAlgorithm_; + uint64_t sampleOffset_; + float sampleWeight_; +}; +struct GetScalarDataRequest { + std::vector data_; +}; +using json = rapidjson::Value; +using document_t = rapidjson::Document; + +template +static inline std::optional TryParseJson(std::string_view jsonStr, std::string &error) { + document_t doc; + doc.Parse(jsonStr.data(), jsonStr.length()); + if (doc.HasParseError()) { + constexpr size_t printErrorSize = 10; + auto offset = doc.GetErrorOffset(); + auto start = offset >= printErrorSize ? offset - printErrorSize : 0; + error = "Error code:" + std::to_string(doc.GetParseError()) + ", str:" + + std::string(jsonStr.substr(start, offset - start + printErrorSize)); + return std::nullopt; + } + return std::move(doc); +} + +static inline document_t ParseJsonToStr(std::string_view jsonStr) { + document_t doc; + doc.Parse(jsonStr.data(), jsonStr.length()); + return doc; +} + +static inline std::string DumpJsonToStr(const json &document) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + document.Accept(writer); + return {buffer.GetString(), buffer.GetSize()}; +} + + +template +static inline void AddJsonMember(json &dst, + std::string_view key, + T &&value, + rapidjson::MemoryPoolAllocator<> &allocator) { + dst.AddMember(rapidjson::StringRef(key.data(), key.length()), std::forward(value), allocator); +} + +static inline void AddJsonMember(json &dst, + std::string_view key, + const std::string &value, + rapidjson::MemoryPoolAllocator<> &allocator) { + dst.AddMember(rapidjson::StringRef(key.data(), key.length()), + rapidjson::Value().SetString(value.data(), value.size(), allocator), + allocator); +} + +static inline void AddJsonMember(json &dst, + std::string_view key, + std::string &value, + rapidjson::MemoryPoolAllocator<> &allocator) { + dst.AddMember(rapidjson::StringRef(key.data(), key.length()), + rapidjson::Value().SetString(value.data(), value.size(), allocator), + allocator); +} + +static inline void AddJsonMember(json &dst, + std::string_view key, + std::string &&value, + rapidjson::MemoryPoolAllocator<> &allocator) { + dst.AddMember(rapidjson::StringRef(key.data(), key.length()), + rapidjson::Value().SetString(value.data(), value.size(), allocator), + allocator); +} + +template +static inline void AddJsonMember(json &dst, + std::string_view key, + std::vector &value, + rapidjson::MemoryPoolAllocator<> &allocator) { + json temp(rapidjson::kArrayType); + for (const T &item: value) { + if constexpr (std::is_same_v) { + temp.PushBack(json().SetString(item.data(), item.size(), allocator), allocator); + } else { + temp.PushBack(std::forward(item), allocator); + } + } + AddJsonMember(dst, key, temp, allocator); +} + +template +static inline void AddJsonMember(json &dst, + std::string_view key, + std::set &value, + rapidjson::MemoryPoolAllocator<> &allocator) { + json temp(rapidjson::kArrayType); + for (const T &item: value) { + if constexpr (std::is_same_v) { + temp.PushBack(json().SetString(item.data(), item.size(), allocator), allocator); + } else { + temp.PushBack(std::forward(item), allocator); + } + } + AddJsonMember(dst, key, temp, allocator); +} + +static inline void SetResponseError(int errCode, const std::string &errMsg, std::string &resultStr) { + document_t result = ParseJsonToStr(resultStr); + auto &allocator = result.GetAllocator(); + result["errCode"].SetInt(errCode); + result["msg"].SetString(errMsg.c_str(), errMsg.size(), allocator); + result["result"].SetBool(false); + resultStr = DumpJsonToStr(result); +} + +static inline std::string GetBasicResponse() { + return R"({"body":{}, "msg":"", "errCode":0, "result":true})"; +} + +static inline std::string GetReadableFileName(std::string_view path) { + if (path.empty()) { + return ""; + } + auto curPath = fs::path(path); + auto fileName = curPath.filename(); + if (!curPath.has_parent_path()) { + return std::string(path); + } + auto parent Path = curPath.parent_path(); + if (!parentPath.has_filename()) { + return fileName.string(); + } + auto parentDir = paren tPath.filename(); + auto res = parentDir / fileName; + return res.string(); +} +} +#endif //BOARD_PLUGINS_SCALARVISUALLY_SRC_PROTOCOLUTIL_PROTOCOLUTIL_H_ diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/defs/ConceptDefs.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/defs/ConceptDefs.h new file mode 100644 index 0000000000000000000000000000000000000000..5d58804541270f7fa39bd13abdec54ee11a2c10b --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/defs/ConceptDefs.h @@ -0,0 +1,41 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef CONCEPTDEFS_H +#define CONCEPTDEFS_H +#include "rapidjson.h" +#include "document.h" +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif +#include + +namespace Insight::Scalar { +using json_t = rapidjson::Value; +using document_t = rapidjson::Document; +enum class ParseDataType { + MindSpore_Summary = 0, + TF_EVENT = 1, + TEXT_LOG = 2, + Unknown = 3 +}; + +enum ErrCode : int { + OK = 0, + INVALID_REQUEST_JSON, + REQUEST_INVALID_PARAM, + INVALID_PATH +}; +struct ScalarPoint { + ScalarPoint() = default; + ScalarPoint(const int64_t step, const float value) : step_(step), value_(value) + {}; + int64_t step_{0}; + float value_{0}; +}; +} +#endif //CONCEPTDEFS_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/FileParser.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/FileParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24350509ddaedd6e5a3f829811fe5c663c15cb37 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/FileParser.cpp @@ -0,0 +1,82 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include "FileParser.h" +#include "Logger.h" +#include "defs/ConceptDefs.h" + +using namespace Insight::Scalar::Parser; +using namespace Insight::Scalar; +using namespace Insight; + +bool FileParser::CheckFilePathVaild(const std::string &filePath) { + if (filePath.empty()) { + LOG(LogRank::Error) << "File path is empty"; + return false; + } + if (!fs::exists(filePath)) { + LOG(LogRank::Error) << "File path does not exist"; + return false; + } + auto readPermission = fs::status(filePath).permissions() & fs::perms::owner_read; + if (readPermission == fs::perms::none) { + LOG(LogRank::Error) << "File not permit to read"; + return false; + } + return true; +} + +std::ifstream FileParser::OpenFileSafe(const std::string &filePath) { + std::ifstream file; + file.setstate(std::ios::badbit); + if (!CheckFilePathVaild(filePath)) { + LOG(LogRank::Error) << "Open file failed"; + return file; + } + file = std::ifstream(filePath, std::ios::in | std::ios::binary); + return file; +} + +std::map > FileParser::ParserData(const std::string &filePath, uint64_t &offset) { + std::map > data; + std::ifstream file = OpenFileSafe(filePath); + if (!file.is_open()) { + LOG(LogRank::Warning) << "Parse data faild, open file error"; + return data; + } + file.seekg(static_cast(offset), std::ios::beg); + std::string recordStr; + while (file && ReadRecord(recordStr, file)) { + if (!ParseRecordToScalar(std::move(recordStr), data)) { + break; + } + std::streampos index = file.tellg(); + if (index != -1) { + offset = index; + } + } + return data; +} + +bool FileParser::ReadCheckSumRecord(std::ifstream &input, std::vector &buffer, size_t size) { + if (!input) { + return false; + } + if (size > std::numeric_limits::max()) { + LOG(LogRank::Error) << "Read data exceed limit"; + return false; + } + + buffer.clear(); + buffer.resize(size + 1); + input.read(buffer.data(), static_cast(size)); + if (input.gcount() != size) { + return false; + } + uint32_t ccrc = 0; + input.read(reinterpret_cast(&ccrc), sizeof(uint32_t)); + if (input.gcount() != sizeof(uint32_t)) { + return false; + } + return true; +} diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/FileParser.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/FileParser.h new file mode 100644 index 0000000000000000000000000000000000000000..8d9e4e5b525743233524978d5994a03a062fe71a --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/FileParser.h @@ -0,0 +1,49 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef FILEPARSER_H +#define FILEPARSER_H + +#include +#include +#include +#include +#include +#include +#include "defs/ConceptDefs.h" + +namespace Insight::Scalar::Parser { +/** + * @brief This base class of data file parser + */ +class FileParser { +public: + FileParser() : type_(ParseDataType::Unknown) {} + + /** + * @brief open file stream safely + * @param filePath + * @return + */ + std::ifstream OpenFileSafe(const std::string &filePath); + /** + * + * @param filePath + * @param[in/out] offset file read offset, update after parsed + * @return + */ + std::map> ParserData(const std::string &filePath, + uint64_t &offset); + virtual bool ReadRecord(std::string &recordStr, std::ifstream &input) = 0; + virtual bool ParseRecordToScalar(std::string &&record, std::map> &res) = 0; + static bool ReadCheckSumRecord(std::ifstream &input, std::vector &buffer, size_t size); + virtual ~FileParser() = default; +private: + bool CheckFilePathVaild(const std::string &filePath); + +public: + ParseDataType type_; +}; +} + +#endif //FILEPARSER_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/LogTextParser.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/LogTextParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b322956f4c1b516e9bd0e46695ba4056f8beac9c --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/LogTextParser.cpp @@ -0,0 +1,44 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include "LogTextParser.h" +#include "Logger.h" +#include + +using namespace Insight::Scalar; +using namespace Insight::Scalar::Parser; + +bool LogTextParser::ReadRecord(std::string &recordStr, std::ifstream &input) { + getline(input, recordStr); + return (!input.eof() && !input.bad()); +} + +bool LogTextParser::ParseRecordToScalar(std::string &&record, std::map> &res) { + if (record.empty()) { + return false; + } + // parse step information + int64_t step = 0; + const auto stepRegex = std::regex(R"((step:\[\s(\d+)/) | (iteration\s{5}(\d+)/))"); + if (std::smatch smatch; std::regex_search(record, smatch, stepRegex)) { + if (smatch.size() > 2 && smatch[2].matched) { + step = std::stoll(smatch[2].str()); + } + if (smatch.size() > 4 && smatch[4].matched) { + step = std::stoll(smatch[4].str()); + } + } else { + return true; + } + float value = 0; + const auto lossRegex = std::regex(R"(loss:\s([-+]?\d*\.?\d+([eE][-+]?\d+)?))"); + if (std::smatch smatch; std::regex_search(record, smatch, lossRegex)) { + if (smatch.size() > 2 && smatch[2].matched) { + value = std::stof(smatch[2].str()); + } + } else { + return true; + } + res[TEXT_DEFAULT_TAG.data()].emplace_back(step, value); + return true; +} diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/LogTextParser.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/LogTextParser.h new file mode 100644 index 0000000000000000000000000000000000000000..eede60cfdffdc7747dfa6c3172d315d57f713a14 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/LogTextParser.h @@ -0,0 +1,27 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef LOGTEXTPARSER_H +#define LOGTEXTPARSER_H + +#include "FileParser.h" +#include + +namespace Insight::Scalar::Parser { +constexpr std::string_view TEXT_DEFAULT_TAG = "Loss"; + +class LogTextParser : public FileParser { +public: + LogTextParser() { + type_ = ParseDataType::TEXT_LOG; + } + + bool ReadRecord(std::string &recordStr, std::ifstream &input) override; + + bool ParseRecordToScalar(std::string &&record, std::map> &res) override; + + ~LogTextParser() override = default; +}; +} + +#endif //LOGTEXTPARSER_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/ParserFactory.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/ParserFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..3a6ef27a03bc09c878ad3a8efbe9c2d617d5681e --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/ParserFactory.h @@ -0,0 +1,46 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef PARSERFACTORY_H +#define PARSERFACTORY_H + +#include + +#include "LogTextParser.h" +#include "SummaryParser.h" +#include "TFEventParser.h" + +namespace Insight::Scalar::Parser { +class ParserFactory { +public: + ParserFactory(const ParserFactory &) = delete; + + ParserFactory &operator=(const ParserFactory &) = delete; + + ParserFactory(ParserFactory &&) = delete; + + static ParserFactory &Instance() { + static ParserFactory instance; + return instance; + } + + std::shared_ptr CreateFileParse(ParseDataType dataType) { + if (parsers_.find(dataType) == parsers_.end()) { + return nullptr; + } + return parsers_[dataType]; + } + +private: + ParserFactory() = default; + + static inline std::map > parsers_ = { + {ParseDataType::TF_EVENT, std::make_shared()}, + {ParseDataType::MindSpore_Summary, std::make_shared()}, + {ParseDataType::TEXT_LOG, std::make_shared()}, + {ParseDataType::Unknown, nullptr} + }; +}; +} + +#endif //PARSERFACTORY_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/SummaryParser.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/SummaryParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cdda9fc3b1d1d1494c38d2fe18b9501e6d0a33b6 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/SummaryParser.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include "SummaryParser.h" +#include "Logger.h" +#include "mindspore_summary.pb.h" + +using namespace Insight::Scalar::Parser; +using namespace Insight::Scalar; +using namespace Insight; + +bool SummaryParser::EventContainsScalar(const mindspore::irpb::Event &event) { + if (!event.has_summary()) { + return false; + } + const mindspore::irpb::Summary &summary = event.summary(); + bool hasScalarValue = + std::any_of(summary.value().begin(), summary.value().end(), + [](const mindspore::irpb::Summary::Value &value) { + return value.value_case() == mindspore::irpb::Summary::Value::kScalarValue; + }); + if (!hasScalarValue) { + return false; + } + return true; +} + +bool SummaryParser::ReadRecord(std::string &eventStr, std::ifstream &input) { + /* + * The structure of tf event file: + * | data length | CRC sum | pb data | CRC sum | + * | uint64_t | uint32_t | .... | uint32_t | + */ + std::vector record; + if (!ReadCheckSumRecord(input, record, sizeof(uint64_t))) { + return false; + } + uint64_t length = 0; + memcpy(&length, record.data(), sizeof(uint64_t)); + if (!ReadCheckSumRecord(input, record, length)) { + return false; + } + eventStr = std::string(record.data(), length); + return true; +} + +bool SummaryParser::ParseRecordToScalar(std::string &&record, std::map> &res) { + mindspore::irpb::Event event; + if (!event.ParseFromString(record)) { + LOG(LogRank::Error) << "Can't convert str to mindspore event"; + return false; + } + if (!EventContainsScalar(event)) { + // not contains scalar, skip + return true; + } + int64_t step = event.step(); + for (const auto &value: event.summary().value()) { + if (value.value_case() != mindspore::irpb::Summary::Value::kScalarValue) { + continue; + } + const std::string &tag = value.tag(); + float scalarValue = value.scalar_value(); + res[tag].emplace_back(step, scalarValue); + } + return true; +} diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/SummaryParser.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/SummaryParser.h new file mode 100644 index 0000000000000000000000000000000000000000..23893a114af05b5eb920f8c6ed7cfe80fe9bb63b --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/SummaryParser.h @@ -0,0 +1,35 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef SUMMARYPARSER_H +#define SUMMARYPARSER_H + +#include "FileParser.h" +#include + +namespace Insight::Scalar::Parser { + +class SummaryParser : public FileParser { +public: + SummaryParser() { + type_ = ParseDataType::MindSpore_Summary; + } + + bool ReadRecord(std::string &eventStr, std::ifstream &input) override; + + bool ParseRecordToScalar(std::string &&record, std::map> &res) override; + + ~SummaryParser() override = default; + +private: + /** + * @brief check event wether contains scalar data + * @param event + * @return + */ + static bool EventContainsScalar(const mindspore::irpb::Event &event); +}; + +} + +#endif //SUMMARYPARSER_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/TFEventParser.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/TFEventParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a78573c28defc7d17e5b103e8d3b1de4d8eba0cc --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/TFEventParser.cpp @@ -0,0 +1,71 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include "TFEventParser.h" +#include "proto/event.pb.h" +#include "Logger.h" + +using namespace Insight::Scalar::Parser; +using namespace Insight::Scalar; +using namespace Insight; + +bool TFEventParser::EventContainsScalar(const tensorboard::Event &event) { + if (!event.has_summary()) { + return false; + } + const auto &summary = event.summary(); + bool hasScalarValue = + std::any_of(summary.value().begin(), summary.value().end(), [](const tensorboard::Summary::Value &item) { + return item.value_case() == tensorboard::Summary::Value::kSimpleValue; + }); + if (!hasScalarValue) { + return false; + } + return true; +} + +bool TFEventParser::ReadRecord(std::string &eventStr, std::ifstream &input) { + /* + * The structure of tf event file: + * | data length | CRC sum | pb data | CRC sum | + * | uint64_t | uint32_t | .... | uint32_t | + */ + std::vector record; + if (!ReadCheckSumRecord(input, record, sizeof(uint64_t))) { + return false; + } + uint64_t length = 0; + memcpy(&length, record.data(), sizeof(uint64_t)); + if (!ReadCheckSumRecord(input, record, length)) { + return false; + } + eventStr = std::string(record.data(), length); + return true; +} + +bool TFEventParser::ParseRecordToScalar(std::string &&record, std::map> &res) { + tensorboard::Event event; + if (!event.ParseFromString(record)) { + LOG(LogRank::Error) << "Can't convert str to tensorflow event"; + return false; + } + if (!EventContainsScalar(event)) { + // event not contains scalar data, skip + return true; + } + const int64_t step = event.step(); + const tensorboard::Summary &summary = event.summary(); + for (const auto &value: summary.value()) { + if (value.value_case() != tensorboard::Summary::Value::kSimpleValue) { + continue; + } + const std::string &tag = value.tag(); + const float scalarValue = value.simple_value(); + ScalarPoint point; + point.step_ = step; + point.value_ = scalarValue; + res[tag].emplace_back(step, scalarValue); + } + return true; +} + diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/TFEventParser.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/TFEventParser.h new file mode 100644 index 0000000000000000000000000000000000000000..7d6525f2adde4a78a5f784b2fdf070e533cec059 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/parser/TFEventParser.h @@ -0,0 +1,34 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef TFEVENTPARSER_H +#define TFEVENTPARSER_H + +#include "FileParser.h" +#include "proto/event.pb.h" +#include "defs/ConceptDefs.h" + +namespace Insight::Scalar::Parser { +class TFEventParser final : public FileParser { +public: + TFEventParser() { + type_ = ParseDataType::TF_EVENT; + } + + ~TFEventParser() override = default; + + bool ReadRecord(std::string &eventStr, std::ifstream &input) override; + + bool ParseRecordToScalar(std::string &&record, std::map> &res) override; + +private: + /** + * @brief check whether contains scalar value + * @param event tf event object + * @return true for success + */ + static bool EventContainsScalar(const tensorboard::Event &event); +}; +} + +#endif //TFEVENTPARSER_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetAllGraphHandler.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetAllGraphHandler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7193e034c3e9eeb8eb42d7be2a40851f0c60448 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetAllGraphHandler.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#include "ScalarVisuallyGetAllGraphHandler.h" +#include "ScalarVisuallyServer.h" +#include "Util/ScalaryProtocolUtil.h" + +#include + +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else + +#include + +namespace fs = std::experimental::filesystem; +#endif + +using namespace Insight::Scalar; +using namespace Insight::Scalar::Protocol; + +bool ScalarVisuallyGetAllGraphHandler::run(std::string_view data, std::string &resultStr) { + resultStr = GetBasicResponse(); + std::unordered_map > res; + res = ScalarVisuallyServer::Instance().GetAllGraphInfo(); + SetResponse(res, resultStr); + return true; +} + +void ScalarVisuallyGetAllGraphHandler::SetResponse(std::unordered_map > &graphInfoMap, + std::string &result) { + document_t document = ParseJsonToStr(result); + auto &allocator = document.GetAllocator(); + json data(rapidjson::kArrayType); + for (auto &[tag, fileList]: graphInfoMap) { + json graphInfo(rapidjson::kObjectType); + AddJsonMember(graphInfo, "tag", tag, allocator); + json fileListArray(rapidjson::kArrayType); + for (auto &file: fileList) { + json item(rapidjson::kObjectType); + AddJsonMember(item, "path", file, allocator); + AddJsonMember(item, "name", GetReadableFileName(file), allocator); + fileListArray.PushBack(item, allocator); + } + AddJsonMember(graphInfo, "fileList", fileListArray, allocator); + data.PushBack(graphInfo, allocator); + } + AddJsonMember(document["body"], "data", data, allocator); + result = DumpJsonToStr(document); +} diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetAllGraphHandler.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetAllGraphHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..3e96eb0caba1608cb07b51ee5b50babab5a871d2 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetAllGraphHandler.h @@ -0,0 +1,24 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_PLUGINS_SCALARVISUALLY_SRC_PLUGIN_HANDLER_SCALARVISUALLYGETALLGRAPHHANDLER_H_ +#define BOARD_PLUGINS_SCALARVISUALLY_SRC_PLUGIN_HANDLER_SCALARVISUALLYGETALLGRAPHHANDLER_H_ + +#include "ApiHandler.h" +#include +#include +#include + +using namespace Dic::Core; +namespace Insight::Scalar { +class ScalarVisuallyGetAllGraphHandler : public PostHandler { +public: + bool run(std::string_view data, std::string &resultStr) override; + +private: + static void + SetResponse(std::unordered_map> &graphInfoMap, std::string &result); +}; +} + +#endif //BOARD_PLUGINS_SCALARVISUALLY_SRC_PLUGIN_HANDLER_SCALARVISUALLYGETALLGRAPHHANDLER_H_ diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetScalarDataHandler.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetScalarDataHandler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e2491798c0f3200b7fd6ea184218123608eb67cb --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetScalarDataHandler.cpp @@ -0,0 +1,137 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#include "ScalarVisuallyGetScalarDataHandler.h" +#include "Logger.h" +#include "ScalarVisuallyServer.h" +#include + +using namespace Insight::Scalar; +using namespace Insight::Scalar::Protocol; + +bool ScalarVisuallyGetScalarDataHandler::run(std::string_view data, std::string &resultStr) { + resultStr = GetBasicResponse(); + std::string errMsg; + GetScalarDataRequest request; + if (auto errCode = ParseRequestFromJson(data, request, errMsg); errCode != ErrCode::OK) { + LOG(LogRank::Error) << "Invalid request, error:" << errMsg; + SetResponseError(errCode, errMsg, resultStr); + return false; + } + + std::vector responseData; + for (const auto &graph: request.data_) { + std::optional grahData = ScalarVisuallyServer::Instance().GetGraphData(graph); + if (!grahData.has_value()) { + LOG(LogRank::Warning) << "No data in graph, tag=" << graph.tag_ << ",file=" << graph.file_ << "offset=" + << graph.offset_; + continue; + } + responseData.emplace_back(std::move(grahData.value())); + } + SetResponse(std::move(responseData), resultStr); + return true; +} + +bool ScalarVisuallyGetScalarDataHandler::CheckParamValid(const document_t &request) { + if (!request.HasMember("graphList") || !request["graphList"].IsArray()) { + return false; + } + return std::all_of(request["graphList"].Begin(), request["graphList"].End(), [](const auto &item) { + if (!item.HasMember("tag") || !item["tag"].IsString()) { + return false; + } + if (!item.HasMember("file") || !item["file"].IsString()) { + return false; + } + if (!item.HasMember("offset") || !item["offset"].IsUint64()) { + return false; + } + if (!item.HasMember("sampleAlgorithm") || !item["sampleAlgorithm"].IsString()) { + return false; + } + if (!item.HasMember("sampleOffset") || !item["sampleOffset"].IsUint64()) { + return false; + } + if (!item.HasMember("sampleWeight") || !item["sampleWeight"].IsNumber()) { + return false; + } + return true; + }); +} + +void ScalarVisuallyGetScalarDataHandler::SetResponse(std::vector &&responseData, + std::string &resultStr) { + document_t document = ParseJsonToStr(resultStr); + auto &allocator = document.GetAllocator(); + json data(rapidjson::kArrayType); + for (auto &graphData: responseData) { + json graph(rapidjson::kObjectType); + AddJsonMember(graph, "tag", graphData.tag_, allocator); + AddJsonMember(graph, "file", graphData.filePath_, allocator); + json graphDatas(rapidjson::kArrayType); + for (auto &point: graphData.graphData_) { + json scalar(rapidjson::kObjectType); + AddJsonMember(scalar, "step", point.step_, allocator); + // process float Value Nan and Inf + if (std::isnan(point.value_) || std::isinf(point.value_)) { + AddJsonMember(scalar, "value", std::to_string(point.value_), allocator); + } else { + AddJsonMember(scalar, "value", point.value_, allocator); + } + graphDatas.PushBack(scalar, allocator); + } + AddJsonMember(graph, "points", graphDatas, allocator); + json sampledDatas(rapidjson::kArrayType); + for (auto &samplePoint: graphData.graphSampledData_) { + json scalar(rapidjson::kObjectType); + AddJsonMember(scalar, "step", samplePoint.step_, allocator); + // process float Value Nan and Inf + if (std::isnan(samplePoint.value_) || std::isinf(samplePoint.value_)) { + AddJsonMember(scalar, "value", std::to_string(samplePoint.value_), allocator); + } else { + AddJsonMember(scalar, "value", samplePoint.value_, allocator); + } + sampledDatas.PushBack(scalar, allocator); + } + AddJsonMember(graph, "sampledPoints", sampledDatas, allocator); + data.PushBack(graph, allocator); + } + AddJsonMember(document["body"], "data", data, allocator); + resultStr = DumpJsonToStr(document); +} + +ErrCode ScalarVisuallyGetScalarDataHandler::ParseRequestFromJson(std::string_view data, + GetScalarDataRequest &request, + std::string &errMsg) { + std::string parseErr; + std::optional document = TryParseJson(data, parseErr); + if (!document.has_value()) { + errMsg = "Invalid request json, err:" + parseErr; + return ErrCode::INVALID_REQUEST_JSON; + } + if (!CheckParamValid(document.value())) { + errMsg = "Invalid request param"; + return ErrCode::REQUEST_INVALID_PARAM; + } + const json &graphList = document.value()["graphList"]; + std::for_each(graphList.Begin(), graphList.End(), [&request](const json &graph) { + SingleGraphReqInfo temp; + temp.tag_ = graph["tag"].GetString(); + temp.file_ = graph["file"].GetString(); + temp.offset_ = graph["offset"].GetUint64(); + temp.sampleAlgorithm_ = graph["sampleAlgorithm"].GetString(); + temp.sampleOffset_ = graph["sampleOffset"].GetUint64(); + temp.sampleWeight_ = graph["sampleWeight"].GetFloat(); + request.data_.emplace_back(std::move(temp)); + }); + return ErrCode::OK; +} + +bool ScalarVisuallyServer::IsFileWatched(std::string &&path) { + return !(fileManager_.GetFileInfo(path) == nullptr); +} + +void ScalarVisuallyServer::GetFileTags(std::set &tags, std::string &path) { + return graphManager_.GetFileTags(path, tags); +} diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetScalarDataHandler.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetScalarDataHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..09756c18de81e1e333bf0052afb31c6d7a5febc2 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyGetScalarDataHandler.h @@ -0,0 +1,33 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_PLUGINS_SCALARVISUALLY_SRC_PLUGIN_HANDLER_SCALARVISUALLYGETSCALARDATAHANDLER_H_ +#define BOARD_PLUGINS_SCALARVISUALLY_SRC_PLUGIN_HANDLER_SCALARVISUALLYGETSCALARDATAHANDLER_H_ + +#include "ApiHandler.h" +#include "ScalarVisuallyServer.h" +#include "Util/ScalaryProtocolUtil.h" +#include "GraphManager/GraphManager.h" + +namespace Insight::Scalar { +using namespace Dic::Core; +using namespace Scalar::Protocol; +using namespace Scalar::GraphOp; + +class ScalarVisuallyGetScalarDataHandler : public PostHandler { +public: + bool run(std::string_view data, std::string &resultStr) override; + + static bool CheckParamValid(const document_t &request); + +private: + static void SetResponse(std::vector &&responseData, + std::string &resultStr); + + static inline ErrCode ParseRequestFromJson(std::string_view data, + GetScalarDataRequest &request, + std::string &errMsg); +}; +} + +#endif //BOARD_PLUGINS_SCALARVISUALLY_SRC_PLUGIN_HANDLER_SCALARVISUALLYGETSCALARDATAHANDLER_H_ diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyImportFileHandler.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyImportFileHandler.cpp new file mode 100644 index 0000000000000000000000000000000000000000..764bbe5548009a2f872c1a80c9f3e8448e16db8a --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyImportFileHandler.cpp @@ -0,0 +1,201 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#include "ScalarVisuallyImportFileHandler.h" +#include "ScalarVisuallyServer.h" +#include "Logger.h" +#include "Util/FileUtil.h" +#include "Util/ScalaryProtocolUtil.h" +#include +#include +#include + +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else + +#include + +namespace fs = std::experimental::filesystem; +#endif +using namespace Insight::Scalar; +using namespace Insight; + +bool ScalarVisuallyImportFileHandler::run(std::string_view data, std::string &resultStr) { + resultStr = GetBasicResponse(); + std::string errMsg; + ImportFileRequest request; + if (auto errCode = ParseReqeustFromJson(data, request, errMsg); errCode != ErrCode::OK) { + LOG(LogRank::Error) << errMsg; + SetResponseError(errCode, errMsg, resultStr); + return false; + } + if (!request.append_) { + ScalarVisuallyServer::Instance().Reset(); + } + auto &pathList = request.pathList_; + + // get all file which need to be import + std::vector importFiles = GetImportFiles(pathList); + // delete the invalid path + const auto end = std::remove_if(pathList.begin(), pathList.end(), [](const auto &path) { + return ScalarVisuallyImportFileHandler::PathInvalid(path); + }); + pathList.erase(end); + // import file, record the tags of parsed result + std::set graphTags; + for (auto &file: importFiles) { + HandleImportFile(std::move(file), graphTags); + } + SetResponse(graphTags, resultStr); + return true; +} + +bool ScalarVisuallyImportFileHandler::PathInvalid(std::string_view path) { + if (path.empty()) { + LOG(LogRank::Info) << "path is empty"; + return true; + } + if (!fs::is_regular_file(path) && !fs::is_directory(path)) { + LOG(LogRank::Info) << "path is not a normal file or dir, path=" << path; + return true; + } + return false; +} + +bool ScalarVisuallyImportFileHandler::HandleImportFile(std::string &&path, std::set &graphTags) { + ScalarVisuallyServer &instance = ScalarVisuallyServer::Instance(); + if (instance.IsFileImported(path)) { + LOG(LogRank::Info) << "File already imported before"; + return true; + } + const std::shared_ptr fileInfo = instance.AddFile(path); + const auto parser = instance.GetFileParser(fileInfo->parseDataType_); + if (!parser) { + return false; + } + std::map > parsedData = parser->ParserData(path, fileInfo->offSet_); + if (parsedData.empty()) { + LOG(LogRank::Error) << "File not contains data, file=" << path; + return false; + } + for (auto &[tag, data]: parsedData) { + graphTags.insert(tag); + ScalarVisuallyServer::Instance().UpdateGraph(tag, path, std::move(data)); + } + fileInfo->imported_ = true; + fileInfo->empty_ = false; + ScalarVisuallyServer::Instance().AddFileWatch({path}); + return true; +} + +std::vector ScalarVisuallyImportFileHandler::GetImportFiles(std::vector &pathList) { + std::vector res; + for (const auto &path: pathList) { + if (!fs::is_directory(path) && ScalarVisuallyServer::IsFileSupported(path)) { + res.emplace_back(path); + } + RecursiveScanFolder(path, res, 7); + } + return res; +} + +void ScalarVisuallyImportFileHandler::SetResponse(std::set &graphTags, std::string &resultStr) { + document_t document = ParseJsonToStr(resultStr); + auto &allocator = document.GetAllocator(); + json &body = document["body"]; + json data(rapidjson::kArrayType); + for (auto &tag: graphTags) { + std::vector dataFiles = ScalarVisuallyServer::Instance().GetGraphInfo(tag); + json graph(rapidjson::kObjectType); + AddJsonMember(graph, "tag", tag, allocator); + json fileList(rapidjson::kArrayType); + for (const std::string &dataFile: dataFiles) { + json file(rapidjson::kObjectType); + AddJsonMember(file, "path", dataFile, allocator); + AddJsonMember(file, "name", GetReadableFileName(dataFile), allocator); + fileList.PushBack(file, allocator); + } + AddJsonMember(graph, "fileList", fileList, allocator); + data.PushBack(graph, allocator); + } + AddJsonMember(body, "data", data, allocator); + resultStr = DumpJsonToStr(document); +} + +void ScalarVisuallyImportFileHandler::AddFileWatch(const std::string &path) { + if (path.empty()) { + return; + } + ScalarVisuallyServer::Instance().AddFileWatch({path}); +} + +bool ScalarVisuallyImportFileHandler::CheckParamValid(const document_t &request) { + if (!request.HasMember("append") || !request["append"].IsBool()) { + return false; + } + if (!request.HasMember("pathList") || !request["pathList"].IsArray()) { + return false; + } + return std::all_of(request["pathList"].Begin(), request["pathList"].End(), [](const json &item) { + return item.IsString(); + }); +} + +ErrCode ScalarVisuallyImportFileHandler::ParseReqeustFromJson(std::string_view data, + Insight::Scalar::ImportFileRequest &request, + std::string &errMsg) { + std::string parseErr; + std::optional document = TryParseJson(data, parseErr); + if (!document.has_value()) { + errMsg = "Invalid request json, error:" + parseErr; + return ErrCode::INVALID_REQUEST_JSON; + } + if (!CheckParamValid(document.value())) { + errMsg = "Invalid request param"; + return ErrCode::REQUEST_INVALID_PARAM; + } + auto &requestJson = document.value(); + request.append_ = requestJson["append"].GetBool(); + std::transform(requestJson["pathList"].Begin(), + requestJson["pathList"].End(), + std::back_inserter(request.pathList_), + [](const auto &item) { + return std::string(item.GetString()); + }); + return ErrCode::OK; +} + +void ScalarVisuallyImportFileHandler::RecursiveScanFolder(const std::string &path, + std::vector &fileList, + int maxDepth) { + if (path.empty() || maxDepth < 0) { + return; + } + if (!fs::exists(path) || !fs::is_directory(path)) { + return; + } + std::queue > searchQueue; + searchQueue.push({path, 0}); + while (!searchQueue.empty()) { + auto [curPath, curDepth] = searchQueue.front(); + searchQueue.pop(); + if (curDepth == maxDepth) { + continue; + } + if (auto per = fs::status(curPath).permissions(); (per & fs::perms::owner_read) == fs::perms::none) { + LOG(LogRank::Error) << "Cur path has no read permission"; + continue; + } + for (const auto &entry: fs::directory_iterator(curPath)) { + if (fs::is_directory(entry)) { + searchQueue.emplace(entry.path().string(), curDepth + 1); + continue; + } + if (fs::is_regular_file(entry) && ScalarVisuallyServer::IsFileSupported(entry.path().string())) { + fileList.emplace_back(entry.path().string()); + } + } + } +} diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyImportFileHandler.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyImportFileHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..81d8b1d03f9361f04b91919c39498a7173de9132 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyImportFileHandler.h @@ -0,0 +1,46 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_SCALARVISUALLYIMPORTFILEHANDLER_H +#define BOARD_SCALARVISUALLYIMPORTFILEHANDLER_H + +#include +#include +#include "ApiHandler.h" +#include "Util/ScalaryProtocolUtil.h" + +using namespace Dic::Core; +namespace Insight::Scalar { +using namespace Insight::Scalar::Protocol; +struct ImportFileRequest { + std::vector pathList_; + bool append_{false}; +}; + +class ScalarVisuallyImportFileHandler : public PostHandler { +public: + bool run(std::string_view data, std::string &resultStr) override; + + static bool CheckParamValid(const document_t &request); + +private: + + static ErrCode ParseReqeustFromJson(std::string_view data, ImportFileRequest &request, std::string &errMsg); + + static bool PathInvalid(std::string_view); + + static std::vector GetImportFiles(std::vector &pathList); + + static bool HandleImportFile(std::string &&path, std::set &graphTags); + + static void SetResponse(std::set &graphTags, std::string &resultStr); + + static void AddFileWatch(const std::string &path); + + static void RecursiveScanFolder(const std::string &path, + std::vector &fileList, + int maxDepth); +}; +} + +#endif //BOARD_SCALARVISUALLYIMPORTFILEHANDLER_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyQueryCreatedFile.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyQueryCreatedFile.cpp new file mode 100644 index 0000000000000000000000000000000000000000..382d119d5a33e04a68dac36f1076a9e87541527b --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyQueryCreatedFile.cpp @@ -0,0 +1,35 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#include "ScalarVisuallyQueryCreatedFile.h" +#include "Logger.h" +#include "defs/ConceptDefs.h" +#include "ScalarVisuallyServer.h" +#include "Util/ScalaryProtocolUtil.h" + +using namespace Insight::Scalar; +using namespace Insight::Scalar::Protocol; + +bool ScalarVisuallyQueryCreatedFile::run(std::string_view data, std::string &resultStr) { + resultStr = GetBasicResponse(); + std::unordered_map> + createFileGroupByDir = ScalarVisuallyServer::Instance().GetCreatedFileGroupByDir(); + SetResponse(createFileGroupByDir, resultStr); + return true; +} + +void ScalarVisuallyQueryCreatedFile::SetResponse(std::unordered_map> &createdFileGroupByDir, + std::string &resultStr) { + document_t result = ParseJsonToStr(resultStr); + auto &allocator = result.GetAllocator(); + json data(rapidjson::kArrayType); + for (auto &[dir, fileList]: createdFileGroupByDir) { + json createFile(rapidjson::kObjectType); + AddJsonMember(createFile, "dir", dir, allocator); + AddJsonMember(createFile, "fileList", fileList, allocator); + data.PushBack(createFile, allocator); + } + AddJsonMember(result["body"], "data", data, allocator); + resultStr = DumpJsonToStr(result); +} diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyQueryCreatedFile.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyQueryCreatedFile.h new file mode 100644 index 0000000000000000000000000000000000000000..9bf13746ef1f05b5a05574da324eaf8e73e79a6d --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/Handler/ScalarVisuallyQueryCreatedFile.h @@ -0,0 +1,24 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef SCALARVISUALLYQUERYNEWFILE_H +#define SCALARVISUALLYQUERYNEWFILE_H + +#include "ApiHandler.h" +#include +#include + +namespace Insight::Scalar { +using namespace Dic::Core; + +class ScalarVisuallyQueryCreatedFile : public PostHandler { +public: + bool run(std::string_view data, std::string &resultStr) override; + +private: + static void SetResponse(std::unordered_map> &createdFileGroupByDir, + std::string &resultStr); +}; +} + +#endif //SCALARVISUALLYQUERYNEWFILE_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisualPlugin.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisualPlugin.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9a3bac8a4ca677b9e1bf4df4d465bdccabdfa4a3 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisualPlugin.cpp @@ -0,0 +1,45 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#include "ScalarVisualPlugin.h" +#include "PluginsManager.h" +#include "Handler/ScalarVisuallyImportFileHandler.h" +#include "Handler/ScalarVisuallyGetAllGraphHandler.h" +#include "Handler/ScalarVisuallyGetScalarDataHandler.h" +#include "Handler/ScalarVisuallyQueryCreatedFile.h" + +using namespace Insight; +namespace Insight::Scalar { +ScalarVisualPlugin::ScalarVisualPlugin() : Dic::Core::BasePlugin("ScalarVisually") { + handlers_.emplace("ImportFile", std::make_shared()); + handlers_.emplace("GetAllGraph", std::make_shared()); + handlers_.emplace("GetScalarData", std::make_shared()); + handlers_.emplace("GetAddFiles", std::make_shared()); +} + +std::map> ScalarVisualPlugin::GetAllHandlers() { + std::map> res(handlers_.begin(), handlers_.end()); + return res; +} + +std::vector ScalarVisualPlugin::GetModuleConfig() { + std::vector res; + document_t moduleConfig; + moduleConfig.SetObject(); + auto &allocator = moduleConfig.GetAllocator(); + AddJsonMember(moduleConfig, "name", "Scalar", allocator); + AddJsonMember(moduleConfig, "requestName", "scalar", allocator); + json_t attributes(rapidjson::kObjectType); + AddJsonMember(attributes, "src", "./plugins/Scalar/index.html", allocator); + AddJsonMember(moduleConfig, "attributes", attributes, allocator); + AddJsonMember(moduleConfig, "isDefault", true, allocator); + AddJsonMember(moduleConfig, "isCluster", true, allocator); + AddJsonMember(moduleConfig, "isCompute", true, allocator); + AddJsonMember(moduleConfig, "isJupyter", true, allocator); + res.push_back(DumpJsonToStr(moduleConfig)); + return res; +} + +} + +PluginRegister pluginRegister(std::move(std::make_unique())); diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisualPlugin.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisualPlugin.h new file mode 100644 index 0000000000000000000000000000000000000000..5da781dfe414244ce2609c4bd92950098f298651 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisualPlugin.h @@ -0,0 +1,28 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_SCALARVISUALPLUGIN_H +#define BOARD_SCALARVISUALPLUGIN_H + +#include "BasePlugin.h" + +namespace Insight::Scalar { +using namespace Dic::Core; + +class ScalarVisualPlugin : public BasePlugin { +public: + ScalarVisualPlugin(); + + ~ScalarVisualPlugin() override = default; + + std::map> GetAllHandlers() override; + + std::vector GetModuleConfig() override; + + +private: + std::map> handlers_; +}; +} + +#endif //BOARD_SCALARVISUALPLUGIN_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisuallyServer.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisuallyServer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d79fe24252806e2770f415e6710d00e706849c1 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisuallyServer.cpp @@ -0,0 +1,96 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#include "ScalarVisuallyServer.h" +#include "Logger.h" +#include "FileManager/FileInfoManager.h" +#include + +using namespace Insight::Scalar; + +bool ScalarVisuallyServer::IsFileImported(std::string_view path) { + auto fileInfo = fileManager_.GetFileInfo(path); + if (fileInfo == nullptr) { + return false; + } + if (!fileInfo->imported_) { + return false; + } + return true; +} + +void ScalarVisuallyServer::OnFileDataUpdate(std::string &&dir, std::string &&fileName) { + std::string filePath = dir + "/" + fileName; + std::shared_ptr file = fileManager_.GetFileInfo(filePath); + if (file == nullptr) { + return; + } + auto parser = GetFileParser(file->parseDataType_); + if (!parser) { + return; + } + std::map> data = parser->ParserData(filePath, file->offSet_); + for (auto &[tag, points]: data) { + graphManager_.UpdateGraphData(tag, filePath, std::move(points)); + } + if (!data.empty() && file->empty_) { + ScalarVisuallyServer::Instance().OnFileCreate(std::move(dir), std::move(fileName)); + file->empty_ = false; + } +} + +bool ScalarVisuallyServer::IsFileSupported(std::string_view path) { + return FileInfo::FileInfoManager::IsFileSupported(path); +} + +std::shared_ptr ScalarVisuallyServer::AddFile(const std::string &file) { + ParseDataType type = FileInfo::FileInfoManager::GetFileType(file); + return fileManager_.AddFile(file, type); +} + +void ScalarVisuallyServer::OnFileCreate(std::string &&dir, std::string &&fileName) { + fileManager_.OnFileCreate(std::move(dir), std::move(fileName)); +} + +std::shared_ptr ScalarVisuallyServer::GetFileParser(ParseDataType type) { + return Parser::ParserFactory::Instance().CreateFileParse(type); +} + +void ScalarVisuallyServer::Reset() { + fileManager_.Reset(); + graphManager_.Reset(); + fileWatcher_->Reset(); +} + +void +ScalarVisuallyServer::UpdateGraph(const std::string &tag, const std::string &path, std::vector &&data) { + graphManager_.UpdateGraphData(tag, path, std::move(data)); +} + +std::vector ScalarVisuallyServer::GetGraphInfo(const std::string &tag) { + std::vector res; + auto graph = graphManager_.GetGraph(tag); + if (graph == nullptr) { + LOG(LogRank::Error) << "Graph not exit, tag=" << tag; + return res; + } + return graph->GetDataFiles(); +} + +std::unordered_map> ScalarVisuallyServer::GetAllGraphInfo() { + return graphManager_.GetAllGraphInfo(); +} + +std::optional ScalarVisuallyServer::GetGraphData(const SingleGraphReqInfo &reqInfo) { + return graphManager_.GetGraphData(reqInfo); +} + +void ScalarVisuallyServer::AddFileWatch(const std::vector &fileList) { + + fileWatcher_->AddWatchPath(fileList); +} + +std::unordered_map> ScalarVisuallyServer::GetCreatedFileGroupByDir() { + return fileManager_.GetCreatedFileGroupByDir(); +} + diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisuallyServer.h b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisuallyServer.h new file mode 100644 index 0000000000000000000000000000000000000000..5b542f50322a25ad697f207ed60c7aab6374b043 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/plugin/ScalarVisuallyServer.h @@ -0,0 +1,82 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. +*/ +#ifndef BOARD_SCALARVISUALLYSERVER_H +#define BOARD_SCALARVISUALLYSERVER_H + +#include +#include +#include "defs/ConceptDefs.h" +#include "FileManager/FileInfoManager.h" +#include "GraphManager/GraphManager.h" +#include "FileManager/FileWatcherFactory.h" +#include "ParserFactory.h" +#include "Logger.h" +#include "Util/ScalaryProtocolUtil.h" + +using namespace Insight; +namespace Insight::Scalar { +using namespace GraphOp; +using namespace Protocol; + +class ScalarVisuallyServer { +public: + static ScalarVisuallyServer &Instance() { + static ScalarVisuallyServer instance; + return instance; + } + + ~ScalarVisuallyServer() = default; + + std::shared_ptr AddFile(const std::string &file); + + bool IsFileImported(std::string_view path); + + void OnFileDataUpdate(std::string &&dir, std::string &&fileName); + + void OnFileCreate(std::string &&dir, std::string &&fileName); + + std::shared_ptr GetFileParser(ParseDataType type); + + /** + * @brief check whether support parse the file + * @param path + * @return + */ + static bool IsFileSupported(std::string_view path); + + void Reset(); + + void UpdateGraph(const std::string &tag, const std::string &path, std::vector &&data); + + std::vector GetGraphInfo(const std::string &tag); + + std::unordered_map > GetAllGraphInfo(); + + std::optional GetGraphData(const SingleGraphReqInfo &reqInfo); + + void AddFileWatch(const std::vector &fileList); + + std::unordered_map > GetCreatedFileGroupByDir(); + + bool IsFileWatched(std::string &&path); + + void GetFileTags(std::set &tags, std::string &path); + +private: + ScalarVisuallyServer() { + fileWatcher_ = FileWatch::FileWatcherFactory::GetFileWatcher(); + if (!fileWatcher_) { + LOG(LogRank::Warning) << "This platform not support file realtime watch"; + return; + } + fileWatcher_->Init(); + }; + + Scalar::FileInfo::FileInfoManager fileManager_{}; + GraphManager graphManager_{}; + std::unique_ptr fileWatcher_{nullptr}; +}; +} + +#endif //BOARD_SCALARVISUALLYSERVER_H diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/test/FileInfoManagerTest.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/test/FileInfoManagerTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffd310f3494d7a2382eef13081bf0c42a0950a73 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/test/FileInfoManagerTest.cpp @@ -0,0 +1,43 @@ +#include "gtest/gtest.h" +#include "FileManager/FileInfoManager.h" + +using namespace Insight::Scalar; +using namespace Insight::Scalar::FileInfo; + +class FileManagerTestSuit : public testing::Test { +protected: +}; + +TEST_F(FileManagerTestSuit, TestAddFile +) { + FileInfoManager manager; + auto fsp1 = manager.AddFile("/root/test.a", ParseDataType::TF_EVENT); + EXPECT_EQ(manager + .GetFileInfo("/root/test.a") == fsp1, true); + EXPECT_EQ(fsp1 + ->filePath_ == "/root/test.a", true); + EXPECT_EQ(fsp1 + ->offSet_, 0); + EXPECT_EQ(fsp1 + ->parseDataType_, ParseDataType::TF_EVENT); +} + +TEST_F(FileManagerTestSuit, TestRepeatAddFile +) { + FileInfoManager manager; + auto fsp1 = manager.AddFile("/root/test.a", ParseDataType::TF_EVENT); + auto fsp2 = manager.AddFile("/root/test.a", ParseDataType::TF_EVENT); + EXPECT_EQ(fsp1, fsp2 + ); + EXPECT_EQ(fsp1 + == nullptr, false); +} + +TEST_F(FileManagerTestSuit, DelFile +) { + FileInfoManager manager; + manager.AddFile("/root/test.a", ParseDataType::TF_EVENT); + manager.DelFile("/root/test.a"); + EXPECT_EQ(manager + .GetFileInfo("/root/test.a") == nullptr, true); +} \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/test/GraphTest.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/test/GraphTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..00ac1aadf4eb31c41d0eb415d852c70e983a41c3 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/test/GraphTest.cpp @@ -0,0 +1,129 @@ +#include "gtest/gtest.h" +#include "GraphManager/GraphManager.h" + +using namespace Insight::Scalar; +using namespace Insight::Scalar::Graph; + +class GraphTestSuit : public testing::Test { +protected: + std::vector lossData = {{0, 0.158}, + {1, 0.11124}, + {2, 0.3124}}; + std::vector lineData = {{0, 0.18}, + {1, 0.24}, + {2, 0.424}}; +}; + +TEST_F(GraphTestSuit, AddGraph +) { + GraphManager manager; + std::vector temp(lossData.begin(), lossData.end()); + manager.UpdateGraphData("loss", "TestA", + std::move(temp) + ); + auto graphPtr = manager.GetGraph("loss"); + EXPECT_NE(graphPtr, + nullptr); + EXPECT_EQ(graphPtr + -> + GetDataFiles() + . + size(), + 1); +} + +TEST_F(GraphTestSuit, GetGraphData +) { + GraphManager manager; + std::vector temp(lossData.begin(), lossData.end()); + manager.UpdateGraphData("loss", "TestA", + std::move(temp) + ); + auto data = manager.GetGraphData(<#initializer#>); + EXPECT_EQ(data + . + has_value(), + true); + EXPECT_EQ(data + . + value() + .filePath_, "TestA"); + EXPECT_EQ(data + . + value() + .tag_, "loss"); + EXPECT_EQ(data + . + value() + .graphData_. + size(), + 3); + EXPECT_EQ(data + . + value() + .graphData_[0].step_, 0); + EXPECT_FLOAT_EQ(data + . + value() + .graphData_[0].value_, 0.158); + auto data2 = manager.GetGraphData(<#initializer#>); + EXPECT_EQ(data2 + . + has_value(), + true); + EXPECT_EQ(data2 + . + value() + .graphData_. + size(), + 2); + EXPECT_EQ(data2 + . + value() + .graphData_[0].step_, 1); + EXPECT_FLOAT_EQ(data2 + . + value() + .graphData_[0].value_, 0.11124); +} + +TEST_F(GraphTestSuit, GetGraphInfo +) { + GraphManager manager; + std::vector temp(lossData.begin(), lossData.end()); + manager.UpdateGraphData("loss", "TestA", + std::move(temp) + ); + std::copy(lossData + . + begin(), lossData + . + begin() + + 1, + std::back_inserter(temp) + ); + manager.UpdateGraphData("loss", "TestB", + std::move(temp) + ); + std::vector tem2(lineData.begin(), lineData.end()); + manager.UpdateGraphData("line", "TestC", + std::move(tem2) + ); + auto graphMap = manager.GetAllGraphInfo(); + EXPECT_EQ(graphMap + .count("loss"), 1); + EXPECT_EQ(graphMap + .count("line"), 1); + EXPECT_EQ(graphMap["loss"] + . + size(), + 2); + EXPECT_EQ(graphMap["loss"][0], + "TestA"); + EXPECT_EQ(graphMap["line"] + . + size(), + 1); + EXPECT_EQ(graphMap["line"][0], + "TestC"); +} \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/test/ParserTest.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/test/ParserTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..20575f7976411c3996bf5a4e54a2a13e0375b87d --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/test/ParserTest.cpp @@ -0,0 +1,51 @@ +#include "gtest/gtest.h" +#include "ParserFactory.h" +#include "proto/event.pb.h" +#include "defs/ConceptDefs.h" + +using namespace Insight::Scalar::Parser; +using namespace Insight::Scalar; + +class ParserTestSuit : public ::testing::Test { +protected: + std::string pathPrefix = "../../../../"; + std::string tfeventTestFile = pathPrefix + + + "mindstudio-board/server/plugins/ScalarVisually/test/test_data/tfevent_data/events.out.tfevents.1728628561"; + + void SetUp() override { + } + +}; + +TEST_F(ParserTestSuit, ParserFactoryGetParserInstance) { + std::shared_ptr parser = nullptr; + parser = ParserFactory::Instance().CreateFileParse(ParseDataType::Unknown); + EXPECT_EQ(parser, nullptr); + parser = ParserFactory::Instance().CreateFileParse(ParseDataType::TF_EVENT); + EXPECT_EQ(parser->type_ == ParseDataType::TF_EVENT, true); + parser = ParserFactory::Instance().CreateFileParse(ParseDataType::MindSpore_Summary); + EXPECT_EQ(parser->type_ == ParseDataType::MindSpore_Summary, true); + parser = ParserFactory::Instance().CreateFileParse(ParseDataType::TEXT_LOG); + EXPECT_EQ(parser->type_ == ParseDataType::TEXT_LOG, true); + parser = ParserFactory::Instance().CreateFileParse(ParseDataType::Unknown); + EXPECT_EQ(parser, nullptr); +} + +TEST_F(ParserTestSuit, ParseTFevent) { + + auto parser = ParserFactory::Instance().CreateFileParse(ParseDataType::TF_EVENT); + uint64_t offset = 0; + auto res = parser->ParserData(tfeventTestFile, offset); + EXPECT_EQ(res.empty(), false); + EXPECT_EQ(offset, 18446744073709551615ull); + EXPECT_EQ(res.size(), 1); + EXPECT_EQ(res.count("Loss/train"), 1); + std::vector &datas = res.at("Loss/train"); + EXPECT_EQ(datas[0].step_, 0); + EXPECT_FLOAT_EQ(datas[0].value_, 0.136831999); +} + +TEST_F(ParserTestSuit, ParseSummaryData) { + auto parser = ParserFactory::Instance().CreateFileParse(ParseDataType::MindSpore_Summary); +} \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/test/PluginTest.cpp b/plugins/mindstudio-insight-plugins/Scalar/server/src/test/PluginTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d3e8f0cea41e19dc30d50d3ab39f3675779aeaf7 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/Scalar/server/src/test/PluginTest.cpp @@ -0,0 +1,177 @@ +#include "gtest/gtest.h" +#include "plugin/ScalarVisualPlugin.h" +#include "ApiHandler.h" +#include "Util/ScalaryProtocolUtil.h" +#include "ScalarVisuallyServer.h" +#include + +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +#else + +#include + +namespace fs = std::experimental::filesystem; +#endif +using namespace Insight::Scalar; +using namespace Dic::Core; +using namespace Insight::Scalar::Protocol; + +class PluginTestSuit : public testing::Test { +protected: + ScalarVisualPlugin plugin_; + + void SetUpTestSuit() { + + }; + + const std::shared_ptr &GetHandler(std::string_view name) { + auto handlerMap = plugin_.GetAllHandlers(); + return handlerMap.at(std::string(name)); + } + + std::string GetBasicResult() { + return R"({"body":{}, "msg":"", "errCode":0, "result":true})"; + } + + void ImportFile(std::string &file) { + auto &imporFileHandler = GetHandler("ImportFile"); + std::string request = R"({"append": false, "pathList": [")" + file + R"("]})"; + std::string resultStr = GetBasicResult(); + bool success = imporFileHandler->run(request, resultStr); + } + + std::string currentPath = fs::current_path().string(); + size_t index = currentPath.find("mindstudio-board"); + std::string pathPreFix = currentPath.substr(0, index); + std::string tfDataDir = pathPreFix + "mindstudio-board/server/plugins/ScalarVisually/test/test_data/tfevent_data"; + std::string tfDataFile = + pathPreFix + + "mindstudio-board/server/plugins/ScalarVisually/test/test_data/tfevent_data/events.out.tfevents.1728628561"; + ScalarVisuallyServer &server = ScalarVisuallyServer::Instance(); + +}; + +TEST_F(PluginTestSuit, ImportFileWithDir) { + auto &imporFileHandler = GetHandler("/ScalarVisually/ImportFile"); + std::string request = R"({"append": false, "pathList": [")" + tfDataDir + R"(]})"; + std::string resultStr = GetBasicResult(); + bool success = imporFileHandler->run(request, resultStr); + document_t result = ParseJsonToStr(resultStr); + json &body = result["body"]; + EXPECT_EQ(success, true); + EXPECT_EQ(body.HasMember("data"), true); + json &data = body["data"]; + EXPECT_EQ(data.IsArray(), true); + EXPECT_EQ(data.Size(), 1); + EXPECT_EQ(data[0].HasMember("tag"), true); + EXPECT_STREQ(data[0]["tag"].GetString(), "Loss/train"); +} + +TEST_F(PluginTestSuit, ImportFileWithFile) { + auto &imporFileHandler = GetHandler("/ScalarVisually/ImportFile"); + std::string request = R"({"append": false, "pathList": [")" + tfDataFile + R"("]})"; + std::string resultStr = GetBasicResult(); + bool success = imporFileHandler->run(request, resultStr); + document_t result = ParseJsonToStr(resultStr); + json &body = result["body"]; + EXPECT_EQ(success, true); + EXPECT_EQ(body.HasMember("data"), true); + json &data = body["data"]; + EXPECT_EQ(data.IsArray(), true); + EXPECT_EQ(data.Size(), 1); + EXPECT_EQ(data[0].HasMember("tag"), true); + EXPECT_STREQ(data[0]["tag"].GetString(), "Loss/train"); +} + +TEST_F(PluginTestSuit, GetLossDataOffSetZero) { + ImportFile(tfDataFile); + auto &GetDataHandler = GetHandler("/ScalarVisually/GetScalarData"); + std::string request = R"({"graphList": [{"tag": "Loss/train", "file": ")" + tfDataFile + R"(", "offset": 0}]})"; + std::string response = GetBasicResult(); + bool succes = GetDataHandler->run(request, response); + EXPECT_EQ(succes, true); + document_t result = ParseJsonToStr(response); + json &body = result["body"]; + json &data = body["data"]; + EXPECT_EQ(data.IsArray(), true); + json &lossData = data[0]; + EXPECT_STREQ(lossData["tag"].GetString(), "Loss/train"); + json &points = lossData["points"]; + EXPECT_EQ(points.IsArray(), true); + EXPECT_EQ(points.Size(), 10000); +} + +TEST_F(PluginTestSuit, GetLossDataOffSet) { + ImportFile(tfDataFile); + auto &GetDataHandler = GetHandler("/ScalarVisually/GetScalarData"); + std::string request = R"({"graphList": [{"tag": "Loss/train", "file": ")" + tfDataFile + R"(", "offset": 10}]})"; + std::string response = GetBasicResult(); + bool succes = GetDataHandler->run(request, response); + EXPECT_EQ(succes, true); + document_t result = ParseJsonToStr(response); + json &body = result["body"]; + json &data = body["data"]; + EXPECT_EQ(data.IsArray(), true); + json &lossData = data[0]; + EXPECT_STREQ(lossData["tag"].GetString(), "Loss/train"); + json &points = lossData["points"]; + EXPECT_EQ(points.IsArray(), true); + EXPECT_EQ(points.Size(), 9990); +} + +TEST_F(PluginTestSuit, GetLossDataOffSetExceedSize) { + ImportFile(tfDataFile); + auto &GetDataHandler = GetHandler("/ScalarVisually/GetScalarData"); + std::string request = R"({"graphList": [{"tag": "Loss/train", "file": ")" + tfDataFile + R"(", "offset": 10010}]})"; + std::string response = GetBasicResult(); + bool succes = GetDataHandler->run(request, response); + EXPECT_EQ(succes, true); + document_t result = ParseJsonToStr(response); + json &body = result["body"]; + json &data = body["data"]; + EXPECT_EQ(data.IsArray(), true); + json &lossData = data[0]; + EXPECT_STREQ(lossData["tag"].GetString(), "Loss/train"); + json &points = lossData["points"]; + EXPECT_EQ(points.IsArray(), true); + EXPECT_EQ(points.Size(), 0); +} + +TEST_F(PluginTestSuit, GetAllGraph) { + ImportFile(tfDataFile); + auto &handler = GetHandler("/ScalarVisually/GetAllGraph"); + std::string request = R"()"; + std::string response = GetBasicResult(); + bool success = handler->run(request, response); + EXPECT_EQ(success, true); + document_t result = ParseJsonToStr(response); + json &body = result["body"]; + json &data = body["data"]; + EXPECT_EQ(data.IsArray(), true); + EXPECT_EQ(data.Size(), 1); + json &graph = data[0]; + EXPECT_STREQ(graph["tag"].GetString(), "Loss/train"); + EXPECT_EQ(graph["fileList"].IsArray() && graph["fileList"].Size() == 1, true); + json &file = graph["fileList"][0]; + EXPECT_EQ(file.IsString(), true); + EXPECT_EQ(tfDataFile.compare(file.GetString()) == 0, true); +} + +TEST_F(PluginTestSuit, ImportNewFile) { + /* + { + "pathList":[ + ], + "append":true + } + */ + server.OnFileCreate(std::string(tfDataDir), fs::path(tfDataFile).filename().string()); + std::string request = R"({"pathList":[")" + tfDataFile + R"("], "append":true})"; + std::string result = GetBasicResult(); + auto handler = GetHandler("ImportFile"); + auto errCode = handler->run(request, result); + EXPECT_EQ(errCode, 0); + +} diff --git a/plugins/mindstudio-insight-plugins/Scalar/server/src/test/test_data/tfevent_data/events.out.tfevents.1728628561 b/plugins/mindstudio-insight-plugins/Scalar/server/src/test/test_data/tfevent_data/events.out.tfevents.1728628561 new file mode 100644 index 0000000000000000000000000000000000000000..5ccfa7c4727bb72c4288eeab37b7c7a0c364de7f Binary files /dev/null and b/plugins/mindstudio-insight-plugins/Scalar/server/src/test/test_data/tfevent_data/events.out.tfevents.1728628561 differ diff --git a/plugins/mindstudio-insight-plugins/plugin_core/CMakeLists.txt b/plugins/mindstudio-insight-plugins/plugin_core/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f96a0daa821615cf48b19b0d4a869e91f731ce1c --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/CMakeLists.txt @@ -0,0 +1,20 @@ +project(msinsight) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_C_STANDARD 11) +aux_source_directory(./src CORE_SRC) +set(LIBRARY_OUTPUT_PATH ${PROJECT_ROOT_DIR}/output/lib) +add_library(${PROJECT_NAME} SHARED ${CORE_SRC}) +target_include_directories(${PROJECT_NAME} PUBLIC ./include) +target_include_directories(${PROJECT_NAME} PRIVATE ./src) +target_include_directories(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/rapidjson/include/rapidjson) +target_compile_definitions(${PROJECT_NAME} PRIVATE PLUGINS_DIR="${PROJECT_ROOT_DIR}/output/plugins") +if (CMAKE_SYSTEM_NAME STREQUAL "Linux") + target_link_libraries(${PROJECT_NAME} PUBLIC dl) + target_link_libraries(${PROJECT_NAME} PUBLIC stdc++fs) +endif () + +if (${CMAKE_BUILD_TYPE} MATCHES "Debug") + target_compile_options(${PROJECT_NAME} PRIVATE -g -O0) +endif () + + diff --git a/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/gtest.cmake b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/gtest.cmake new file mode 100644 index 0000000000000000000000000000000000000000..d0a1069a3040df432ad9f38342a60b5f1a88ab38 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/gtest.cmake @@ -0,0 +1,60 @@ +set(gtest_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") +set(gtest_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") + +set(CMAKE_OPTION + -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON + -DCMAKE_MACOSX_RPATH=TRUE) +if(BUILD_LITE) + if(PLATFORM_ARM64 AND CMAKE_SYSTEM_NAME MATCHES "Android") + set(CMAKE_OPTION -DCMAKE_TOOLCHAIN_FILE=$ENV{ANDROID_NDK}/build/cmake/android.toolchain.cmake + -DANDROID_NATIVE_API_LEVEL=19 + -DANDROID_NDK=$ENV{ANDROID_NDK} + -DANDROID_ABI=arm64-v8a + -DANDROID_TOOLCHAIN_NAME=aarch64-linux-android-clang + -DANDROID_STL=${ANDROID_STL} + ${CMAKE_OPTION}) + endif() + if(PLATFORM_ARM32 AND CMAKE_SYSTEM_NAME MATCHES "Android") + set(CMAKE_OPTION -DCMAKE_TOOLCHAIN_FILE=$ENV{ANDROID_NDK}/build/cmake/android.toolchain.cmake + -DANDROID_NATIVE_API_LEVEL=19 + -DANDROID_NDK=$ENV{ANDROID_NDK} + -DANDROID_ABI=armeabi-v7a + -DANDROID_TOOLCHAIN_NAME=aarch64-linux-android-clang + -DANDROID_STL=${ANDROID_STL} + ${CMAKE_OPTION}) + endif() +endif() + +if(NOT ENABLE_GLIBCXX) + set(gtest_CXXFLAGS "${gtest_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") +endif() + +if(BUILD_LITE) + if(ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.1.tar.gz") + set(SHA256 "9bf1fe5182a604b4135edc1a425ae356c9ad15e9b23f9f12a02e80184c3a249c") + else() + set(REQ_URL "https://github.com/google/googletest/archive/release-1.8.1.tar.gz") + set(SHA256 "9bf1fe5182a604b4135edc1a425ae356c9ad15e9b23f9f12a02e80184c3a249c") + endif() + + insight_add_pkg(gtest + VER 1.8.1 + DOWNLOAD_ONLY ON + URL ${REQ_URL} + SHA256 ${SHA256}) +else() + if(ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.12.1.tar.gz") + set(SHA256 "81964fe578e9bd7c94dfdb09c8e4d6e6759e19967e397dbea48d1c10e45d0df2") + else() + set(REQ_URL "https://github.com/google/googletest/archive/release-1.12.1.tar.gz") + set(SHA256 "81964fe578e9bd7c94dfdb09c8e4d6e6759e19967e397dbea48d1c10e45d0df2") + endif() + + insight_add_pkg(gtest + VER 1.12.1 + DOWNLOAD_ONLY ON + URL ${REQ_URL} + SHA256 ${SHA256}) +endif() \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/json.cmake b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/json.cmake new file mode 100644 index 0000000000000000000000000000000000000000..47dda916140afd084d21ecaf501a7c6977ae7b59 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/json.cmake @@ -0,0 +1,32 @@ +if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. + set(GIT_REPOSITORY "http://gitee.com/Tencent/RapidJSON.git") + set(TAG "6089180ecb704cb2b136777798fa1be303618975") + set(INCLUDE "./include/rapidjson") +else() + set(GIT_REPOSITORY "http://github.com/Tencent/rapidjson.git") + set(TAG "6089180ecb704cb2b136777798fa1be303618975") + set(INCLUDE "./include/rapidjson") +endif() + + +set(ENABLE_NATIVE_JSON "off") +if(EXISTS ${TOP_DIR}/mindspore/lite/providers/json/native_json.cfg) + set(ENABLE_NATIVE_JSON "on") +endif() +if(ENABLE_NATIVE_JSON) + file(STRINGS ${TOP_DIR}/mindspore/lite/providers/json/native_json.cfg native_json_path) + insight_add_pkg(rapidjson + GIT_TAG ${TAG} + HEAD_ONLY ${INCLUDE} + DIR ${native_json_path}) + add_library(mindspore::json ALIAS rapidjson) +else() + insight_add_pkg(rapidjson + GIT_TAG ${TAG} + HEAD_ONLY ${INCLUDE} + GIT_REPOSITORY ${GIT_REPOSITORY} + SHA256 ${SHA256} + ) + include_directories(${rapidjson_INC}) + add_library(mindspore::json ALIAS rapidjson) +endif() \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/libuv.cmake b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/libuv.cmake new file mode 100644 index 0000000000000000000000000000000000000000..3d7044033e4bfb183d6d806d6faea03a0bfe2665 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/libuv.cmake @@ -0,0 +1,21 @@ +if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. + set(REQ_URL "https://gitee.com/mirrors/libuv/repository/archive/v1.48.0.tar.gz") + set(SHA256 "9b833b426922e2eb568631aa11bdac526e556627b1cbfbc196e4674adc3d43af") + set(VER "1.48.0") +else() + set(REQ_URL "https://github.com/libuv/libuv/archive/refs/tags/v1.48.0.tar.gz") + set(SHA256 "8c253adb0f800926a6cbd1c6576abae0bc8eb86a4f891049b72f9e5b7dc58f33") + set(VER "1.48.0") +endif() + +add_compile_definitions("LIBUS_NO_SSL") +add_compile_definitions("UWS_NO_ZLIB") +add_compile_definitions("LIBUS_USE_LIBUV") + +insight_add_pkg(libuv + VER ${VER} + DOWNLOAD_ONLY ON + URL ${REQ_URL} + SHA256 ${SHA256}) + +set_target_properties(uv PROPERTIES EXCLUDE_FROM_ALL true) \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/protobuf.cmake b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/protobuf.cmake new file mode 100644 index 0000000000000000000000000000000000000000..ab6cd01c6f2cc34352f996d0b3224aefcaf9b001 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/protobuf.cmake @@ -0,0 +1,160 @@ +set(protobuf_USE_STATIC_LIBS ON) +set(ENABLE_NATIVE_PROTOBUF "off") +if(EXISTS ${TOP_DIR}/mindspore/lite/providers/protobuf/native_protobuf.cfg) + set(ENABLE_NATIVE_PROTOBUF "on") + file(STRINGS ${TOP_DIR}/mindspore/lite/providers/protobuffer/native_protobuffer.cfg native_protobuffer_path) +endif() +if(BUILD_LITE) + if(MSVC) + set(protobuf_CXXFLAGS "${CMAKE_CXX_FLAGS}") + set(protobuf_CFLAGS "${CMAKE_C_FLAGS}") + set(protobuf_LDFLAGS "${CMAKE_SHARED_LINKER_FLAGS}") + set(_ms_tmp_CMAKE_STATIC_LIBRARY_PREFIX ${CMAKE_STATIC_LIBRARY_PREFIX}) + set(CMAKE_STATIC_LIBRARY_PREFIX "lib") + if(DEBUG_MODE) + set(protobuf_Debug ON) + endif() + else() + set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter \ + -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2") + if(NOT ENABLE_GLIBCXX) + set(protobuf_CXXFLAGS "${protobuf_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") + endif() + set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") + endif() +else() + if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-uninitialized -Wno-unused-parameter -fPIC \ + -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2") + elseif(${CMAKE_SYSTEM_NAME} MATCHES "Windows") + if(MSVC) + set(protobuf_CXXFLAGS "/DWIN32 /D_WINDOWS /W3 /GR /EHsc") + set(protobuf_CFLAGS "${CMAKE_C_FLAGS}") + set(protobuf_LDFLAGS "${CMAKE_SHARED_LINKER_FLAGS}") + set(_ms_tmp_CMAKE_STATIC_LIBRARY_PREFIX ${CMAKE_STATIC_LIBRARY_PREFIX}) + set(CMAKE_STATIC_LIBRARY_PREFIX "lib") + if(DEBUG_MODE) + set(protobuf_Debug ON) + endif() + else() + set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter \ + -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2") + endif() + else() + set(protobuf_CXXFLAGS "-fstack-protector-all -Wno-maybe-uninitialized -Wno-unused-parameter \ + -fPIC -fvisibility=hidden -D_FORTIFY_SOURCE=2 -O2") + if(NOT ENABLE_GLIBCXX) + set(protobuf_CXXFLAGS "${protobuf_CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") + endif() + set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") + endif() +endif() + +set(_ms_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) +set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS}) +string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") +string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + +if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. + set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz") + set(SHA256 "ab9b39e7053a6fb06b01bf75fb6ec6a71a1ada5a5f8e2446f927336e97b9e7bb") +else() + set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz") + set(SHA256 "9b4ee22c250fe31b16f1a24d61467e40780a3fbb9b91c3b65be2a376ed913a1a") +endif() + +if(BUILD_LITE) + set(PROTOBUF_PATCH_ROOT ${TOP_DIR}/third_party/patch/protobuf) +else() + set(PROTOBUF_PATCH_ROOT ${CMAKE_SOURCE_DIR}/third_party/patch/protobuf) +endif() + +if(MSVC) + insight_add_pkg(protobuf + VER 3.13.0 + LIBS protobuf + EXE protoc + URL ${REQ_URL} + SHA256 ${SHA256} + CMAKE_PATH cmake/ + CMAKE_OPTION -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release + -Dprotobuf_MSVC_STATIC_RUNTIME=OFF) +else() + insight_add_pkg(protobuf + VER 3.13.0 + LIBS protobuf + EXE protoc + URL ${REQ_URL} + SHA256 ${SHA256} + CMAKE_PATH cmake/ + CMAKE_OPTION -Dprotobuf_BUILD_TESTS=OFF -Dprotobuf_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release) +endif() +include_directories(${protobuf_INC}) +add_library(mindboard::protobuf ALIAS protobuf::protobuf) +set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS}) +# recover original value +if(MSVC) + set(CMAKE_STATIC_LIBRARY_PREFIX, ${_ms_tmp_CMAKE_STATIC_LIBRARY_PREFIX}) +endif() + +if(ENABLE_NATIVE_PROTOBUF) + set(PROTOC ${native_protobuffer_path}/bin/protoc) + set(PROTOBUF_LIB ${native_protobuffer_path}/lib/libprotobuf.so.3.13.0.0) + set(protobuf_LIBPATH ${native_protobuffer_path}/lib) + set(protobuf_INC ${native_protobuffer_path}/include) + + include_directories(${protobuf_INC}) + message("protobuf_INC : ${protobuf_INC}") + set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS}) +endif() +add_custom_target(PROTO_GENERATE) +function(common_protobuf_generate target path c_var h_var) + if(NOT ARGN) + message(SEND_ERROR "Error: ms_protobuf_generate() called without any proto files") + return() + endif() + + set(${c_var}) + set(${h_var}) + + foreach(file ${ARGN}) + get_filename_component(abs_file ${file} ABSOLUTE) + get_filename_component(file_name ${file} NAME_WE) + get_filename_component(file_dir ${abs_file} PATH) + file(RELATIVE_PATH rel_path ${CMAKE_CURRENT_SOURCE_DIR} ${file_dir}) + + list(APPEND ${c_var} "${path}/${file_name}.pb.cc") + list(APPEND ${h_var} "${path}/${file_name}.pb.h") + if(ENABLE_NATIVE_PROTOBUF) + add_custom_command( + TARGET ${target} + #OUTPUT "${path}/${file_name}.pb.cc" "${path}/${file_name}.pb.h" + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + COMMAND ${CMAKE_COMMAND} -E make_directory "${path}" + COMMAND ${CMAKE_COMMAND} -E env "LD_LIBRARY_PATH=${protobuf_LIBPATH}" ${PROTOC} -I${file_dir} + --cpp_out=${path} ${abs_file} + DEPENDS ${PROTOC} ${abs_file} + COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM) + else() + get_target_property(PROTOC_EXEC protobuf::protoc IMPORT_LOCATION) + execute_process( + RESULT_VARIABLE RETURN_CODE + COMMAND ${PROTOC_EXEC} -I${file_dir} --cpp_out=${path} ${abs_file} + OUTPUT_VARIABLE STDOUT + ERROR_VARIABLE STDERR + COMMAND_ERROR_IS_FATAL ANY + ) + message(STATUS "Running C++ protocol buffer compiler on ${file}") + endif() + endforeach() + + set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) + set(${c_var} ${${c_var}} PARENT_SCOPE) + set(${h_var} ${${h_var}} PARENT_SCOPE) +endfunction() + +function(ms_protobuf_generate target path c_var h_var) + common_protobuf_generate(${target} ${path} ${c_var} ${h_var} ${ARGN}) + set(${c_var} ${${c_var}} PARENT_SCOPE) + set(${h_var} ${${h_var}} PARENT_SCOPE) +endfunction() diff --git a/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/uSockets.cmake b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/uSockets.cmake new file mode 100644 index 0000000000000000000000000000000000000000..e45927d5bc2b2ceee737f1e20960697e0a502740 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/uSockets.cmake @@ -0,0 +1,17 @@ +set(REQ_URL "https://github.com/uNetworking/uSockets/archive/refs/tags/v0.8.6.tar.gz") +set(SHA256 "16eba133dd33eade2f5f8dd87612c04b5dd711066e0471c60d641a2f6a988f16") +set(VER "0.8.6") + +insight_add_pkg(uSockets + VER ${VER} + HEAD_ONLY src + URL ${REQ_URL} + SHA256 ${SHA256}) + +list(APPEND MAIN_INCLUDE ${usockets_INC}) + +aux_source_directory(${usockets_INC} U_SOCKETS_SRC) +aux_source_directory(${usockets_INC}/crypto U_SOCKETS_SRC) +aux_source_directory(${usockets_INC}/eventing U_SOCKETS_SRC) +aux_source_directory(${usockets_INC}/internal U_SOCKETS_SRC) + diff --git a/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/uWebSockets.cmake b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/uWebSockets.cmake new file mode 100644 index 0000000000000000000000000000000000000000..43b258e4b34f8bef3433a42ab81eb901fd2805da --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/cmake/external_libs/uWebSockets.cmake @@ -0,0 +1,19 @@ +if(ENABLE_GITEE OR ENABLE_GITEE_EULER) # Channel GITEE_EULER is NOT supported now, use GITEE instead. + set(REQ_URL "https://gitee.com/mirrors/uWebSockets/repository/archive/v20.48.0.tar.gz") + set(SHA256 "7992d0c8b11e2ec4d32cc38e8880b0b5c3115d9e3c3a0988c96df14d88b958df") + set(VER "20.48.0") +else() + set(REQ_URL "https://github.com/uNetworking/uWebSockets/archive/refs/tags/v20.48.0.tar.gz") + set(SHA256 "d7455bbbf9829b3960d0478dd36ed0eba82847c4fc801416aaf89ccb7f4dfb85") + set(VER "20.48.0") +endif() + +insight_add_pkg(uWebSockets + VER ${VER} + HEAD_ONLY src + URL ${REQ_URL} + SHA256 ${SHA256}) + +list(APPEND MAIN_INCLUDE ${uwebsockets_INC}) + + diff --git a/plugins/mindstudio-insight-plugins/plugin_core/cmake/mind_expression.cmake b/plugins/mindstudio-insight-plugins/plugin_core/cmake/mind_expression.cmake new file mode 100644 index 0000000000000000000000000000000000000000..408878e61dfb542f8807e0298fb32e5419fecc2e --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/cmake/mind_expression.cmake @@ -0,0 +1,31 @@ +set(SECURE_CXX_FLAGS "") +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + if(WIN32) + set(SECURE_CXX_FLAGS "-fstack-protector-all") + else() + set(SECURE_CXX_FLAGS "-fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") + endif() +endif() +set(_ms_tmp_CMAKE_CXX_FLAGS_F ${CMAKE_CXX_FLAGS}) + +if(NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") +endif() + +set(TOP_DIR ${CMAKE_SOURCE_DIR}) +set(ENABLE_GITEE ON) + +include(${CMAKE_CURRENT_SOURCE_DIR}/plugin_core/cmake/options.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/plugin_core/cmake/utils.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/plugin_core/cmake/external_libs/protobuf.cmake) +#include(${CMAKE_CURRENT_SOURCE_DIR}/plugin_core/cmake/external_libs/json.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/plugin_core/cmake/external_libs/libuv.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/plugin_core/cmake/external_libs/uSockets.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/plugin_core/cmake/external_libs/uWebSockets.cmake) + + +if(ENABLE_TESTCASES OR ENABLE_CPP_ST) + include(${CMAKE_CURRENT_SOURCE_DIR}/plugin_core/cmake/external_libs/gtest.cmake) +endif() + +set(CMAKE_CXX_FLAGS ${_ms_tmp_CMAKE_CXX_FLAGS_F}) diff --git a/plugins/mindstudio-insight-plugins/plugin_core/cmake/options.cmake b/plugins/mindstudio-insight-plugins/plugin_core/cmake/options.cmake new file mode 100644 index 0000000000000000000000000000000000000000..661f6357b3f03d7a95d23304751c95de85a1e9f7 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/cmake/options.cmake @@ -0,0 +1,69 @@ +option(ENABLE_TESTCASES "Run testcases switch, default off" ON) +option(ENABLE_CPP_ST "Run cpp st testcases switch, default off" ON) +option(DEBUG_MODE "Debug mode, default off" OFF) +option(ENABLE_COVERAGE "Enable code coverage report" OFF) +option(ENABLE_ASAN "Enable Google Sanitizer to find memory bugs") +option(ENABLE_DEBUGGER "enable debugger" OFF) +option(ENABLE_GLIBCXX "enable_glibcxx" ON) +option(BUILD_DEV_MODE "MindBoard build nightly dev mode" OFF) +option(USE_LLVM "Use llvm" OFF) + +if(NOT CMAKE_SYSTEM_NAME MATCHES "Linux") + set(ENABLE_GLIBCXX ON) +endif() + +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(WIN32) + set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fstack-protector-all") + else() + set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") + endif() +endif() + +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -Wsign-compare") +endif() + +if(ENABLE_COVERAGE) + set(COVERAGE_COMPILER_FLAGS "-g --coverage -fprofile-arcs -ftest-coverage") + set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} ${COVERAGE_COMPILER_FLAGS}") +endif() + +if(ENABLE_ASAN) + set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -fsanitize=address -fsanitize-recover=address -fno-omit-frame-pointer") + if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(OPTION_CXX_FLAGS "${OPTION_CXX_FLAGS} -static-libsan") + endif() + add_compile_definitions(ENABLE_ASAN) +endif() + +if(DEBUG_MODE) + set(CMAKE_BUILD_TYPE "Debug") + add_compile_definitions(MEM_REUSE_DEBUG) +else() + set(CMAKE_BUILD_TYPE "Release") +endif() + +if((CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") OR (CMAKE_BUILD_TYPE STREQUAL Release)) + set(PYBIND11_LTO_CXX_FLAGS FALSE) +endif() + +if(NOT BUILD_PATH) + set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") +endif() + +if(ENABLE_TESTCASES) + add_compile_definitions(ENABLE_TEST) +endif() + +if(ENABLE_DEBUGGER) + add_compile_definitions(ENABLE_DEBUGGER) +endif() + +if(ENABLE_DEBUGGER OR ENABLE_TESTCASES) + set(MS_BUILD_GRPC ON) +endif() + +if(USE_LLVM) + add_compile_definitions(USE_LLVM) +endif() diff --git a/plugins/mindstudio-insight-plugins/plugin_core/cmake/utils.cmake b/plugins/mindstudio-insight-plugins/plugin_core/cmake/utils.cmake new file mode 100644 index 0000000000000000000000000000000000000000..42862dece0292d04871538dec1354d4076edd278 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/cmake/utils.cmake @@ -0,0 +1,584 @@ +include(FetchContent) +set(FETCHCONTENT_QUIET OFF) + +if(CMAKE_SYSTEM_NAME MATCHES "Windows" AND ${CMAKE_VERSION} VERSION_GREATER_EQUAL 3.17.0) + set(CMAKE_FIND_LIBRARY_SUFFIXES .dll ${CMAKE_FIND_LIBRARY_SUFFIXES}) +endif() + +function(insight_add_submodule_obj des_submodule_objs sub_dir submodule_name_obj) + + add_subdirectory(${sub_dir}) + + if(NOT TARGET ${submodule_name_obj}) + message(FATAL_ERROR "Can not find submodule '${submodule_name_obj}'. in ${CMAKE_CURRENT_LIST_FILE}") + endif() + if("$" IN_LIST ${des_submodule_objs}) + message(FATAL_ERROR "submodule '${submodule_name_obj}' added more than once. in ${CMAKE_CURRENT_LIST_FILE}") + endif() + + set(${des_submodule_objs} ${${des_submodule_objs}} $ PARENT_SCOPE) + +endfunction() + +if(DEFINED ENV{MSLIBS_CACHE_PATH}) + set(_MS_LIB_CACHE $ENV{MSLIBS_CACHE_PATH}) +else() + set(_MS_LIB_CACHE ${CMAKE_BINARY_DIR}/.mslib) +endif() +message("MS LIBS CACHE PATH: ${_MS_LIB_CACHE}") + +if(NOT EXISTS ${_MS_LIB_CACHE}) + file(MAKE_DIRECTORY ${_MS_LIB_CACHE}) +endif() + +if(DEFINED ENV{MSLIBS_SERVER} AND NOT ENABLE_GITEE) + set(LOCAL_LIBS_SERVER $ENV{MSLIBS_SERVER}) + message("LOCAL_LIBS_SERVER: ${LOCAL_LIBS_SERVER}") +endif() + +include(ProcessorCount) +ProcessorCount(N) +if(JOBS) + set(THNUM ${JOBS}) +else() + set(JOBS 8) + if(${JOBS} GREATER ${N}) + set(THNUM ${N}) + else() + set(THNUM ${JOBS}) + endif() +endif() +message("set make thread num: ${THNUM}") + +if(LOCAL_LIBS_SERVER) + if(NOT ENV{no_proxy}) + set(ENV{no_proxy} "${LOCAL_LIBS_SERVER}") + else() + string(FIND $ENV{no_proxy} ${LOCAL_LIBS_SERVER} IP_POS) + if(${IP_POS} EQUAL -1) + set(ENV{no_proxy} "$ENV{no_proxy},${LOCAL_LIBS_SERVER}") + endif() + endif() +endif() + +function(__download_pkg pkg_name pkg_url pkg_sha256) + + if(LOCAL_LIBS_SERVER) + set(REGEX_IP_ADDRESS "^([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)$") + get_filename_component(_URL_FILE_NAME ${pkg_url} NAME) + if(${LOCAL_LIBS_SERVER} MATCHES ${REGEX_IP_ADDRESS}) + set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${_URL_FILE_NAME}" ${pkg_url}) + else() + set(pkg_url "https://${LOCAL_LIBS_SERVER}/libs/${pkg_name}/${_URL_FILE_NAME}" ${pkg_url}) + endif() + endif() + + FetchContent_Declare( + ${pkg_name} + URL ${pkg_url} + URL_HASH SHA256=${pkg_sha256} + ) + FetchContent_GetProperties(${pkg_name}) + message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}") + FetchContent_MakeAvailable(${pkg_name}) + set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) + +endfunction() + +function(__download_pkg_with_git pkg_name pkg_url pkg_git_commit pkg_sha256) + + if(LOCAL_LIBS_SERVER) + set(pkg_url "http://${LOCAL_LIBS_SERVER}:8081/libs/${pkg_name}/${pkg_git_commit}") + FetchContent_Declare( + ${pkg_name} + URL ${pkg_url} + URL_HASH SHA256=${pkg_sha256} + ) + else() + FetchContent_Declare( + ${pkg_name} + GIT_REPOSITORY ${pkg_url} + GIT_TAG ${pkg_git_commit}) + endif() + FetchContent_GetProperties(${pkg_name}) + message("download: ${${pkg_name}_SOURCE_DIR} , ${pkg_name} , ${pkg_url}") + FetchContent_MakeAvailable(${pkg_name}) + set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) + +endfunction() + + +function(__find_pkg_then_add_target pkg_name pkg_exe lib_path) + set(options) + set(oneValueArgs PATH) + set(multiValueArgs SUFFIXES_PATH NAMES) + cmake_parse_arguments(LIB "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + unset(${pkg_name}_LIBS) + + message("_FIND:${${pkg_name}_BASE_DIR}") + + if(pkg_exe) + unset(${pkg_exe}_EXE CACHE) + find_program(${pkg_exe}_EXE ${pkg_exe} PATHS ${${pkg_name}_BASE_DIR}/bin NO_DEFAULT_PATH) + if(NOT ${pkg_exe}_EXE) + return() + endif() + add_executable(${pkg_name}::${pkg_exe} IMPORTED GLOBAL) + set_target_properties(${pkg_name}::${pkg_exe} PROPERTIES + IMPORTED_LOCATION ${${pkg_exe}_EXE} + ) + message("found ${${pkg_exe}_EXE}") + endif() + + foreach(_LIB_NAME ${LIB_NAMES}) + set(_LIB_SEARCH_NAME ${_LIB_NAME}) + if(MSVC AND ${pkg_name}_Debug) + set(_LIB_SEARCH_NAME ${_LIB_SEARCH_NAME}d) + endif() + set(_LIB_TYPE SHARED) + if(${pkg_name}_USE_STATIC_LIBS) + set(_LIB_SEARCH_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${_LIB_SEARCH_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(_LIB_TYPE STATIC) + endif() + set(${_LIB_NAME}_LIB ${_LIB_NAME}_LIB-NOTFOUND) + if(APPLE) + find_library(${_LIB_NAME}_LIB ${_LIB_SEARCH_NAME} PATHS ${${pkg_name}_BASE_DIR}/${lib_path} + PATH_SUFFIXES ${LIB_SUFFIXES_PATH} NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH) + else() + find_library(${_LIB_NAME}_LIB ${_LIB_SEARCH_NAME} PATHS ${${pkg_name}_BASE_DIR}/${lib_path} + PATH_SUFFIXES ${LIB_SUFFIXES_PATH} NO_DEFAULT_PATH) + endif() + if(NOT ${_LIB_NAME}_LIB) + message("not find ${_LIB_SEARCH_NAME} in path: ${${pkg_name}_BASE_DIR}/${lib_path}") + return() + endif() + + add_library(${pkg_name}::${_LIB_NAME} ${_LIB_TYPE} IMPORTED GLOBAL) + if(WIN32 AND ${_LIB_TYPE} STREQUAL "SHARED") + if(DEBUG_MODE) + set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES IMPORTED_IMPLIB_DEBUG ${${_LIB_NAME}_LIB}) + else() + set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES IMPORTED_IMPLIB_RELEASE ${${_LIB_NAME}_LIB}) + endif() + else() + set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES IMPORTED_LOCATION ${${_LIB_NAME}_LIB}) + endif() + + if(EXISTS ${${pkg_name}_BASE_DIR}/include) + set_target_properties(${pkg_name}::${_LIB_NAME} PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${${pkg_name}_BASE_DIR}/include") + endif() + + list(APPEND ${pkg_name}_LIBS ${pkg_name}::${_LIB_NAME}) + message("found ${${_LIB_NAME}_LIB}") + STRING(REGEX REPLACE "(.+)/(.+)" "\\1" LIBPATH ${${_LIB_NAME}_LIB}) + set(${pkg_name}_LIBPATH ${LIBPATH} CACHE STRING INTERNAL) + endforeach() + + set(${pkg_name}_LIBS ${${pkg_name}_LIBS} PARENT_SCOPE) +endfunction() + +function(__exec_cmd) + set(options) + set(oneValueArgs WORKING_DIRECTORY) + set(multiValueArgs COMMAND) + + cmake_parse_arguments(EXEC "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + execute_process(COMMAND ${EXEC_COMMAND} + WORKING_DIRECTORY ${EXEC_WORKING_DIRECTORY} + RESULT_VARIABLE RESULT) + if(NOT RESULT EQUAL "0") + message(FATAL_ERROR "error! when ${EXEC_COMMAND} in ${EXEC_WORKING_DIRECTORY}") + endif() +endfunction() + +function(__check_patches pkg_patches) + # check patches + if(PKG_PATCHES) + file(TOUCH ${_MS_LIB_CACHE}/${pkg_name}_patch.sha256) + file(READ ${_MS_LIB_CACHE}/${pkg_name}_patch.sha256 ${pkg_name}_PATCHES_SHA256) + + message("patches sha256:${${pkg_name}_PATCHES_SHA256}") + + set(${pkg_name}_PATCHES_NEW_SHA256) + foreach(_PATCH ${PKG_PATCHES}) + file(SHA256 ${_PATCH} _PF_SHA256) + set(${pkg_name}_PATCHES_NEW_SHA256 "${${pkg_name}_PATCHES_NEW_SHA256},${_PF_SHA256}") + endforeach() + + if(NOT ${pkg_name}_PATCHES_SHA256 STREQUAL ${pkg_name}_PATCHES_NEW_SHA256) + set(${pkg_name}_PATCHES ${PKG_PATCHES}) + file(REMOVE_RECURSE "${_MS_LIB_CACHE}/${pkg_name}-subbuild") + file(WRITE ${_MS_LIB_CACHE}/${pkg_name}_patch.sha256 ${${pkg_name}_PATCHES_NEW_SHA256}) + message("patches changed : ${${pkg_name}_PATCHES_NEW_SHA256}") + endif() + endif() +endfunction() + +set(MS_FIND_NO_DEFAULT_PATH NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH NO_SYSTEM_ENVIRONMENT_PATH + NO_CMAKE_BUILDS_PATH NO_CMAKE_PACKAGE_REGISTRY NO_CMAKE_SYSTEM_PATH + NO_CMAKE_SYSTEM_PACKAGE_REGISTRY) +function(insight_add_pkg pkg_name) + + set(options) + set(oneValueArgs URL SHA256 GIT_REPOSITORY GIT_TAG VER EXE DIR DOWNLOAD_ONLY HEAD_ONLY CMAKE_PATH RELEASE + LIB_PATH CUSTOM_CMAKE) + set(multiValueArgs + CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS + INSTALL_LIBS PATCHES SUBMODULES SOURCEMODULES ONLY_MAKE ONLY_MAKE_INCS ONLY_MAKE_LIBS + LIB_SUFFIXES_PATH) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if(NOT PKG_LIB_PATH) + set(PKG_LIB_PATH lib) + endif() + + if(NOT PKG_EXE) + set(PKG_EXE 0) + endif() + + set(__FIND_PKG_NAME ${pkg_name}) + string(TOLOWER ${pkg_name} pkg_name) + message("pkg name:${__FIND_PKG_NAME},${pkg_name}") + + set(${pkg_name}_PATCHES_HASH) + foreach(_PATCH ${PKG_PATCHES}) + file(SHA256 ${_PATCH} _PF_SHA256) + set(${pkg_name}_PATCHES_HASH "${${pkg_name}_PATCHES_HASH},${_PF_SHA256}") + endforeach() + + # strip directory variables to ensure third party packages are installed in consistent locations + string(REPLACE ${TOP_DIR} "" ARGN_STRIPPED ${ARGN}) + string(REPLACE ${_MS_LIB_CACHE} "" ARGN_STRIPPED ${ARGN_STRIPPED}) + # check options + set(${pkg_name}_CONFIG_TXT + "${CMAKE_CXX_COMPILER_VERSION}-${CMAKE_C_COMPILER_VERSION} + ${ARGN_STRIPPED}-${${pkg_name}_USE_STATIC_LIBS}-${${pkg_name}_PATCHES_HASH} + ${${pkg_name}_CXXFLAGS}-${${pkg_name}_CFLAGS}-${${pkg_name}_LDFLAGS}") + if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(${pkg_name}_CONFIG_TXT "${${pkg_name}_CONFIG_TXT}--${CMAKE_OSX_DEPLOYMENT_TARGET}") + endif() + string(REPLACE ";" "-" ${pkg_name}_CONFIG_TXT ${${pkg_name}_CONFIG_TXT}) + string(SHA256 ${pkg_name}_CONFIG_HASH ${${pkg_name}_CONFIG_TXT}) + + message("${pkg_name} config hash: ${${pkg_name}_CONFIG_HASH}") + + set(${pkg_name}_BASE_DIR ${_MS_LIB_CACHE}/${pkg_name}_${PKG_VER}_${${pkg_name}_CONFIG_HASH}) + set(${pkg_name}_DIRPATH ${${pkg_name}_BASE_DIR} CACHE STRING INTERNAL) + + if(EXISTS ${${pkg_name}_BASE_DIR}/options.txt AND PKG_HEAD_ONLY) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE) + add_library(${pkg_name} INTERFACE) + target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC}) + if(${PKG_RELEASE}) + __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIB_PATH} + SUFFIXES_PATH ${PKG_LIB_SUFFIXES_PATH} + NAMES ${PKG_LIBS}) + endif() + return() + endif() + + set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR}) + set(${__FIND_PKG_NAME}_ROOT ${${pkg_name}_BASE_DIR} PARENT_SCOPE) + + if(PKG_LIBS) + __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIB_PATH} + SUFFIXES_PATH ${PKG_LIB_SUFFIXES_PATH} + NAMES ${PKG_LIBS}) + if(${pkg_name}_LIBS) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) + message("Found libs: ${${pkg_name}_LIBS}") + return() + endif() + elseif(NOT PKG_HEAD_ONLY) + find_package(${__FIND_PKG_NAME} ${PKG_VER} PATHS ${${pkg_name}_BASE_DIR} ${MS_FIND_NO_DEFAULT_PATH}) + if(${__FIND_PKG_NAME}_FOUND) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) + message("Found pkg: ${__FIND_PKG_NAME}") + return() + endif() + endif() + + if(NOT PKG_DIR) + if(PKG_GIT_REPOSITORY) + __download_pkg_with_git(${pkg_name} ${PKG_GIT_REPOSITORY} ${PKG_GIT_TAG} ${PKG_SHA256}) + else() + __download_pkg(${pkg_name} ${PKG_URL} ${PKG_SHA256}) + endif() + foreach(_SUBMODULE_FILE ${PKG_SUBMODULES}) + STRING(REGEX REPLACE "(.+)_(.+)" "\\1" _SUBMODEPATH ${_SUBMODULE_FILE}) + STRING(REGEX REPLACE "(.+)/(.+)" "\\2" _SUBMODENAME ${_SUBMODEPATH}) + file(GLOB ${pkg_name}_INSTALL_SUBMODULE ${_SUBMODULE_FILE}/*) + file(COPY ${${pkg_name}_INSTALL_SUBMODULE} DESTINATION ${${pkg_name}_SOURCE_DIR}/3rdparty/${_SUBMODENAME}) + endforeach() + else() + set(${pkg_name}_SOURCE_DIR ${PKG_DIR}) + endif() + file(WRITE ${${pkg_name}_BASE_DIR}/options.txt ${${pkg_name}_CONFIG_TXT}) + message("${pkg_name}_SOURCE_DIR : ${${pkg_name}_SOURCE_DIR}") + + foreach(_PATCH_FILE ${PKG_PATCHES}) + get_filename_component(_PATCH_FILE_NAME ${_PATCH_FILE} NAME) + + # convert line-endings of patch file to UNIX LF + set(_LF_PATCH_FILE ${CMAKE_BINARY_DIR}/_ms_patch/${_PATCH_FILE_NAME}) + configure_file(${_PATCH_FILE} ${_LF_PATCH_FILE} NEWLINE_STYLE LF @ONLY) + + # convert line-endings of source file to be patched to UNIX LF + file(READ ${_LF_PATCH_FILE} _LF_PATCH_CONTENT) + string(REGEX MATCHALL "diff --git a/[/A-Za-z0-9\.\-_]*" _PATCH_SOURCE_LIST "${_LF_PATCH_CONTENT}") + list(TRANSFORM _PATCH_SOURCE_LIST REPLACE "diff --git a/" "") # strip prefix of file path + + foreach(_PATCH_SOURCE ${_PATCH_SOURCE_LIST}) + if(EXISTS ${${pkg_name}_SOURCE_DIR}/${_PATCH_SOURCE}) + execute_process(COMMAND bash -c "sed -i \'s@\\r@@g\' ${${pkg_name}_SOURCE_DIR}/${_PATCH_SOURCE}" + COMMAND_ECHO STDOUT) + endif() + endforeach() + + # apply patch + message("patching ${${pkg_name}_SOURCE_DIR} -p1 < ${_LF_PATCH_FILE}") + execute_process(COMMAND ${Patch_EXECUTABLE} -p1 INPUT_FILE ${_LF_PATCH_FILE} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR} + RESULT_VARIABLE Result) + if(NOT Result EQUAL "0") + message(FATAL_ERROR "Failed patch: ${_LF_PATCH_FILE}") + endif() + endforeach() + foreach(_SOURCE_DIR ${PKG_SOURCEMODULES}) + file(GLOB ${pkg_name}_INSTALL_SOURCE ${${pkg_name}_SOURCE_DIR}/${_SOURCE_DIR}/*) + file(COPY ${${pkg_name}_INSTALL_SOURCE} DESTINATION ${${pkg_name}_BASE_DIR}/${_SOURCE_DIR}/) + endforeach() + file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600) + if(NOT ${pkg_name}_LOCK_RET EQUAL "0") + message(FATAL_ERROR "error! when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}") + endif() + + if(PKG_CUSTOM_CMAKE) + file(GLOB ${pkg_name}_cmake ${PKG_CUSTOM_CMAKE}/CMakeLists.txt) + file(COPY ${${pkg_name}_cmake} DESTINATION ${${pkg_name}_SOURCE_DIR}) + endif() + + if(${pkg_name}_SOURCE_DIR) + if (PKG_DOWNLOAD_ONLY) + file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*) + file(COPY ${${pkg_name}_SOURCE_SUBDIRS} DESTINATION ${${pkg_name}_BASE_DIR}) + set(${pkg_name}_SOURCE_DIR ${${pkg_name}_SOURCE_DIR} PARENT_SCOPE) + elseif(PKG_HEAD_ONLY) + file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*) + file(COPY ${${pkg_name}_SOURCE_SUBDIRS} DESTINATION ${${pkg_name}_BASE_DIR}) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/${PKG_HEAD_ONLY} PARENT_SCOPE) + if(NOT PKG_RELEASE) + add_library(${pkg_name} INTERFACE) + target_include_directories(${pkg_name} INTERFACE ${${pkg_name}_INC}) + endif() + + elseif(PKG_ONLY_MAKE) + __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_CXXFLAGS} -j${THNUM} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) + set(PKG_INSTALL_INCS ${PKG_ONLY_MAKE_INCS}) + set(PKG_INSTALL_LIBS ${PKG_ONLY_MAKE_LIBS}) + file(GLOB ${pkg_name}_INSTALL_INCS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_INCS}) + file(GLOB ${pkg_name}_INSTALL_LIBS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_LIBS}) + file(COPY ${${pkg_name}_INSTALL_INCS} DESTINATION ${${pkg_name}_BASE_DIR}/include) + file(COPY ${${pkg_name}_INSTALL_LIBS} DESTINATION ${${pkg_name}_BASE_DIR}/lib) + + elseif(PKG_CMAKE_OPTION) + # in cmake + file(MAKE_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) + if(${pkg_name}_CFLAGS) + set(${pkg_name}_CMAKE_CFLAGS "-DCMAKE_C_FLAGS=${${pkg_name}_CFLAGS}") + endif() + if(${pkg_name}_CXXFLAGS) + set(${pkg_name}_CMAKE_CXXFLAGS "-DCMAKE_CXX_FLAGS=${${pkg_name}_CXXFLAGS}") + endif() + + if(${pkg_name}_LDFLAGS) + if(${pkg_name}_USE_STATIC_LIBS) + #set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_STATIC_LINKER_FLAGS=${${pkg_name}_LDFLAGS}") + else() + set(${pkg_name}_CMAKE_LDFLAGS "-DCMAKE_SHARED_LINKER_FLAGS=${${pkg_name}_LDFLAGS}") + endif() + endif() + if(APPLE) + __exec_cmd(COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_COMPILER_ARG1=${CMAKE_CXX_COMPILER_ARG1} + -DCMAKE_C_COMPILER_ARG1=${CMAKE_C_COMPILER_ARG1} ${PKG_CMAKE_OPTION} + ${${pkg_name}_CMAKE_CFLAGS} ${${pkg_name}_CMAKE_CXXFLAGS} ${${pkg_name}_CMAKE_LDFLAGS} + -DCMAKE_INSTALL_PREFIX=${${pkg_name}_BASE_DIR} ${${pkg_name}_SOURCE_DIR}/${PKG_CMAKE_PATH} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) + __exec_cmd(COMMAND ${CMAKE_COMMAND} --build . --target install -- + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) + else() + __exec_cmd(COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_COMPILER_ARG1=${CMAKE_CXX_COMPILER_ARG1} + -DCMAKE_C_COMPILER_ARG1=${CMAKE_C_COMPILER_ARG1} ${PKG_CMAKE_OPTION} -G ${CMAKE_GENERATOR} + ${${pkg_name}_CMAKE_CFLAGS} ${${pkg_name}_CMAKE_CXXFLAGS} ${${pkg_name}_CMAKE_LDFLAGS} + -DCMAKE_INSTALL_PREFIX=${${pkg_name}_BASE_DIR} ${${pkg_name}_SOURCE_DIR}/${PKG_CMAKE_PATH} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) + if(MSVC) + set(CONFIG_TYPE Release) + if(DEBUG_MODE) + set(CONFIG_TYPE Debug) + endif() + __exec_cmd(COMMAND ${CMAKE_COMMAND} --build . --config ${CONFIG_TYPE} --target install -- + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) + else() + __exec_cmd(COMMAND ${CMAKE_COMMAND} --build . --target install -- -j${THNUM} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}/_build) + endif() + endif() + else() + if(${pkg_name}_CFLAGS) + set(${pkg_name}_MAKE_CFLAGS "CFLAGS=${${pkg_name}_CFLAGS}") + endif() + if(${pkg_name}_CXXFLAGS) + set(${pkg_name}_MAKE_CXXFLAGS "CXXFLAGS=${${pkg_name}_CXXFLAGS}") + endif() + if(${pkg_name}_LDFLAGS) + set(${pkg_name}_MAKE_LDFLAGS "LDFLAGS=${${pkg_name}_LDFLAGS}") + endif() + # in configure && make + if(PKG_PRE_CONFIGURE_COMMAND) + __exec_cmd(COMMAND ${PKG_PRE_CONFIGURE_COMMAND} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) + endif() + + if(PKG_CONFIGURE_COMMAND) + __exec_cmd(COMMAND ${PKG_CONFIGURE_COMMAND} + ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS} + --prefix=${${pkg_name}_BASE_DIR} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) + endif() + set(${pkg_name}_BUILD_OPTION ${PKG_BUILD_OPTION}) + if(NOT PKG_CONFIGURE_COMMAND) + set(${pkg_name}_BUILD_OPTION ${${pkg_name}_BUILD_OPTION} + ${${pkg_name}_MAKE_CFLAGS} ${${pkg_name}_MAKE_CXXFLAGS} ${${pkg_name}_MAKE_LDFLAGS}) + endif() + # build + if(APPLE) + __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_BUILD_OPTION} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) + else() + __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} ${${pkg_name}_BUILD_OPTION} -j${THNUM} + WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) + endif() + + if(PKG_INSTALL_INCS OR PKG_INSTALL_LIBS) + file(GLOB ${pkg_name}_INSTALL_INCS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_INCS}) + file(GLOB ${pkg_name}_INSTALL_LIBS ${${pkg_name}_SOURCE_DIR}/${PKG_INSTALL_LIBS}) + file(COPY ${${pkg_name}_INSTALL_INCS} DESTINATION ${${pkg_name}_BASE_DIR}/include) + file(COPY ${${pkg_name}_INSTALL_LIBS} DESTINATION ${${pkg_name}_BASE_DIR}/lib) + else() + __exec_cmd(COMMAND ${CMAKE_MAKE_PROGRAM} install WORKING_DIRECTORY ${${pkg_name}_SOURCE_DIR}) + endif() + endif() + endif() + + if(PKG_LIBS) + __find_pkg_then_add_target(${pkg_name} ${PKG_EXE} ${PKG_LIB_PATH} + SUFFIXES_PATH ${PKG_LIB_SUFFIXES_PATH} + NAMES ${PKG_LIBS}) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) + if(NOT ${pkg_name}_LIBS) + message(FATAL_ERROR "Can not find pkg: ${pkg_name}") + endif() + else() + find_package(${__FIND_PKG_NAME} ${PKG_VER} QUIET ${MS_FIND_NO_DEFAULT_PATH}) + if(${__FIND_PKG_NAME}_FOUND) + set(${pkg_name}_INC ${${pkg_name}_BASE_DIR}/include PARENT_SCOPE) + message("Found pkg: ${${__FIND_PKG_NAME}_LIBRARIES}") + return() + endif() + endif() +endfunction() + +function(src_separate_compile) + set(options) + set(oneValueArgs OBJECT_NAME OBJECT_SIZE) + set(multiValueArgs SRC_LIST) + cmake_parse_arguments(STUDENT "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + list(LENGTH STUDENT_SRC_LIST SRC_LIST_SIZE) + + set(SEPARATE_SIZE 100) + set(SEPARATE_INDEX 0) + set(OBJECT_COUNT 0) + math(EXPR SRC_LIST_MAX_INDEX "${SRC_LIST_SIZE} - 1") + while(${SRC_LIST_SIZE} GREATER ${SEPARATE_INDEX}) + math(EXPR SEPARATE_END "${SEPARATE_INDEX} + ${SEPARATE_SIZE} - 1") + if(${SEPARATE_END} GREATER ${SRC_LIST_MAX_INDEX}) + math(EXPR SEPARATE_SIZE "${SRC_LIST_SIZE} - ${SEPARATE_INDEX}") + endif() + list(SUBLIST STUDENT_SRC_LIST ${SEPARATE_INDEX} ${SEPARATE_SIZE} new_sub_list) + math(EXPR OBJECT_COUNT "${OBJECT_COUNT} + 1") + math(EXPR SEPARATE_INDEX "${SEPARATE_INDEX} + ${SEPARATE_SIZE}") + add_library(${STUDENT_OBJECT_NAME}_${OBJECT_COUNT} OBJECT ${new_sub_list}) + endwhile() + set(${STUDENT_OBJECT_SIZE} "${OBJECT_COUNT}" PARENT_SCOPE) + message("${STUDENT_OBJECT_SIZE} object count is ${OBJECT_COUNT}") +endfunction() + +function(enable_target_when_only_build_plugins target) + if(ONLY_BUILD_DEVICE_PLUGINS) + get_target_property(target_type ${target} TYPE) + if(target_type STREQUAL "INTERFACE_LIBRARY") + return() + endif() + set_target_properties(${target} PROPERTIES EXCLUDE_FROM_ALL FALSE) + endif() +endfunction() + +function(disable_target_when_only_build_plugins target) + if(ONLY_BUILD_DEVICE_PLUGINS) + get_target_property(target_type ${target} TYPE) + if(target_type STREQUAL "INTERFACE_LIBRARY") + return() + endif() + get_property(is_set TARGET ${target} PROPERTY EXCLUDE_FROM_ALL) + if(NOT DEFINED is_set) + set_target_properties(${target} PROPERTIES EXCLUDE_FROM_ALL TRUE) + endif() + endif() +endfunction() + +function(enable_directory_when_only_build_plugins dir) + get_property(targets DIRECTORY ${dir} PROPERTY BUILDSYSTEM_TARGETS) + foreach(target ${targets}) + enable_target_when_only_build_plugins(${target}) + endforeach() + get_property(items DIRECTORY ${dir} PROPERTY SUBDIRECTORIES) + foreach(item ${items}) + enable_directory_when_only_build_plugins(${item}) + endforeach() +endfunction() + +function(disable_directory_when_only_build_plugins dir) + get_property(targets DIRECTORY ${dir} PROPERTY BUILDSYSTEM_TARGETS) + foreach(target ${targets}) + disable_target_when_only_build_plugins(${target}) + endforeach() + get_property(items DIRECTORY ${dir} PROPERTY SUBDIRECTORIES) + foreach(item ${items}) + disable_directory_when_only_build_plugins(${item}) + endforeach() +endfunction() + +function(add_subdirectory_with_faster_option dir) + if(ONLY_BUILD_DEVICE_PLUGINS) + add_subdirectory(${dir}) + disable_directory_when_only_build_plugins(${dir}) + else() + add_subdirectory(${dir}) + endif() +endfunction() + +function(find_and_use_mold) + find_program(MOLD_LINKER mold) + if(MOLD_LINKER) + message(STATUS "using mold to speed linking libraries") + get_filename_component(MOLD_LINKER_PATH ${MOLD_LINKER} DIRECTORY) + file(GLOB MOLD_LINKER_PATH "${MOLD_LINKER_PATH}/../libexec/mold") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -B${MOLD_LINKER_PATH}") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -B${MOLD_LINKER_PATH}") + endif() +endfunction() \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/plugin_core/include/ApiHandler.h b/plugins/mindstudio-insight-plugins/plugin_core/include/ApiHandler.h new file mode 100644 index 0000000000000000000000000000000000000000..14ab8d7efa22317a3f2c8924ebd7f2b92b270ac9 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/include/ApiHandler.h @@ -0,0 +1,37 @@ +/* + * Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef MSINSIGHT_APIHANDLER_H +#define MSINSIGHT_APIHANDLER_H +#include "string" + +namespace Dic::Core { +enum class API_TYPE { + GET, + POST +}; +class ApiHandler { +public: + ApiHandler(API_TYPE type): apiType_(type) {}; + virtual ~ApiHandler() = default; + virtual bool run(std::string_view data, std::string &result) = 0; + API_TYPE GetApiType() { return apiType_; }; +protected: + API_TYPE apiType_; +}; + +class GetHandler : public ApiHandler { +public: + GetHandler(): ApiHandler(API_TYPE::GET) {}; + virtual ~GetHandler() = default; + virtual bool run(std::string_view data, std::string &result) = 0; +}; + +class PostHandler : public ApiHandler { +public: + PostHandler(): ApiHandler(API_TYPE::POST) {}; + virtual ~PostHandler() = default; + virtual bool run(std::string_view data, std::string &result) = 0; +}; +} +#endif // MSINSIGHT_APIHANDLER_H diff --git a/plugins/mindstudio-insight-plugins/plugin_core/include/BaseModule.h b/plugins/mindstudio-insight-plugins/plugin_core/include/BaseModule.h new file mode 100644 index 0000000000000000000000000000000000000000..8bb72705140573dd0769e9efc79ae733eba68268 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/include/BaseModule.h @@ -0,0 +1,26 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef MSINSIGHT_BASEMODULE_H_ +#define MSINSIGHT_BASEMODULE_H_ +#include "ProtocolUtil.h" +#include "ModuleRequestHeadler.h" + +namespace Dic::Module { +class BaseModule { +public: + BaseModule() = default; + + virtual ~BaseModule() = default; + + virtual void RegisterRequestHandlers() = 0; + + virtual void OnRequest(std::unique_ptr request); + +protected: + std::string moduleName = MODULE_UNKNOWN; + std::map > requestHandlerMap; +}; +} + +#endif //BOARD_MINDSTUDIO_BOARD_SERVER_CORE_INCLUDE_BASEMODULE_H_ diff --git a/plugins/mindstudio-insight-plugins/plugin_core/include/BasePlugin.h b/plugins/mindstudio-insight-plugins/plugin_core/include/BasePlugin.h new file mode 100644 index 0000000000000000000000000000000000000000..6b95b28869c53a12fb28c9c1f976cec2ac1c9ea8 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/include/BasePlugin.h @@ -0,0 +1,37 @@ +/* + * Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef MSINSIGHT_BASEPLUGIN_H +#define MSINSIGHT_BASEPLUGIN_H +#include +#include "BaseModule.h" +#include "ProtocolUtil.h" +#include "ApiHandler.h" + +namespace Dic::Core { +class BasePlugin { +public: + explicit BasePlugin(std::string pluginName): pluginName_(pluginName) {}; + virtual ~BasePlugin() = default; + std::string GetPluginName() { return pluginName_; }; + + virtual std::unique_ptr GetModule() { return nullptr; }; + + virtual std::unique_ptr GetProtocolUtil() { return nullptr; }; + + virtual std::map> GetAllHandlers() { return {}; }; + + virtual std::vector GetModuleConfig() { return {}; }; + + virtual uint8_t GetOrder() { return UINT8_MAX; }; + +protected: + std::string pluginName_; +}; + +struct PluginRegister { + explicit PluginRegister(std::unique_ptr plugin); +}; +} + +#endif // MSINSIGHT_BASEPLUGIN_H diff --git a/plugins/mindstudio-insight-plugins/plugin_core/include/Logger.h b/plugins/mindstudio-insight-plugins/plugin_core/include/Logger.h new file mode 100644 index 0000000000000000000000000000000000000000..46cb8fe5b2d8928cf06fc90f268d9520e7845083 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/include/Logger.h @@ -0,0 +1,52 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef BOARD_LOGGER_H +#define BOARD_LOGGER_H + +#include +#include + +namespace Insight { +enum class LogRank:int { + Error= 0, + Warning, + Info, + Exception +}; + +class Logger { +public: + static Logger &GetLogger(); + + template + Logger &operator<<(const T &val) noexcept + { + std::cout << val; + return *this; + } + + Logger &operator<<(std::ostream &func(std::ostream &os)) noexcept + { + std::cout << func; + return *this; + } + + Logger &start(std::string_view rank); +private: + Logger() = default; + ~Logger() = default; +}; + +static inline Logger &LOG(enum Insight::LogRank level) +{ + static std::unordered_map levelsMap = { + {LogRank::Error, "Error"}, + {LogRank::Info, "Info"}, + {LogRank::Warning, "Warn"}, + {LogRank::Exception, "Exception"} + }; + return Logger::GetLogger().start(levelsMap[level]); +} +} +#endif //BOARD_LOGGER_H diff --git a/plugins/mindstudio-insight-plugins/plugin_core/include/ModuleRequestHeadler.h b/plugins/mindstudio-insight-plugins/plugin_core/include/ModuleRequestHeadler.h new file mode 100644 index 0000000000000000000000000000000000000000..cfb29e1683e32d68f48d6c902675c62e6c683422 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/include/ModuleRequestHeadler.h @@ -0,0 +1,33 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef MSINSIGHT_REQUEST_HEADLER_H +#define MSINSIGHT_REQUEST_HEADLER_H +#include "ProtocolUtil.h" +#include "WsSender.h" +namespace Dic::Module { +using namespace Dic::Protocol; +class ModuleRequestHandler { +public: + ModuleRequestHandler() = default; + virtual ~ModuleRequestHandler() = default; + + virtual const std::string GetError(); + + virtual void HandleRequest(std::unique_ptr requestPtr) = 0; + + virtual bool IsAsync(); + +public: + static void SetBaseResponse(const Request& request, Response &response); + + static void SetResponseResult(Response& response, bool result, const std::string& errorMsg = "", const int errCode = UNKNOW_ERROR); +protected: + std::string command; + std::string error; + std::string moduleName = MODULE_UNKNOWN; + bool async = true; +}; +} + +#endif //BOARD_MINDSTUDIO_BOARD_SERVER_CORE_INCLUDE_MODULEREQUESTHEADLER_H_ diff --git a/plugins/mindstudio-insight-plugins/plugin_core/include/PluginsManager.h b/plugins/mindstudio-insight-plugins/plugin_core/include/PluginsManager.h new file mode 100644 index 0000000000000000000000000000000000000000..2cfbf1f7f97122b29b63b61cebdb571532e86c75 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/include/PluginsManager.h @@ -0,0 +1,28 @@ +/* + * Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef BOARD_PLUGINSMANAGER_H +#define BOARD_PLUGINSMANAGER_H + +#include "map" +#include "memory" + +#include "BasePlugin.h" + +namespace Dic::Core { +class PluginsManager { +public: + static PluginsManager &Instance(); + bool RegisterPlugin(std::unique_ptr plugin); + std::map>& GetAllPlugins(); + static void LoadPlugins(); +private: + PluginsManager() = default; + ~PluginsManager() = default; + +private: + std::map> pluginsMap_; +}; +} + +#endif // BOARD_PLUGINSMANAGER_H diff --git a/plugins/mindstudio-insight-plugins/plugin_core/include/ProtocolUtil.h b/plugins/mindstudio-insight-plugins/plugin_core/include/ProtocolUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..670ace0817eaf32c3e3c5510aa3db1bf5f7aad48 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/include/ProtocolUtil.h @@ -0,0 +1,121 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef MSINSIGHT_PROTOCOL_UTIL_H +#define MSINSIGHT_PROTOCOL_UTIL_H +#include +#include +#include +#include +#include +#include +#include +#include "rapidjson.h" +#include "document.h" +#include "writer.h" +#include "stringbuffer.h" +namespace Dic { +using json_t = rapidjson::Value; +using document_t = rapidjson::Document; +using namespace rapidjson; +const static int UNKNOW_ERROR = 1001; +const std::string MODULE_UNKNOWN = "unknown"; +namespace Protocol { +struct ErrorMessage { + int code = 0; + std::string message; +}; +struct ProtocolMessage { + enum class Type: int { + REQUEST = 0, RESPONSE, EVENT, NONE + }; + virtual ~ProtocolMessage() = default; + unsigned int id = 0; + ProtocolMessage::Type type = Type::NONE; + std::string moduleName = MODULE_UNKNOWN; + std::optional resultCallbackId; +}; +struct Request : public ProtocolMessage { + explicit Request(const std::string &command) : command(command) + { + type = ProtocolMessage::Type::REQUEST; + } + explicit Request(std::string_view command) : command(std::string(command)) + { + type = ProtocolMessage::Type::REQUEST; + } + ~Request() override = default; + std::string projectName; + std::string command; + std::tuple ParamVaild() const + { + return {true, ""}; + } +}; +struct Response: public ProtocolMessage { + explicit Response(const std::string &command) : command(command) + { + type = ProtocolMessage::Type::RESPONSE; + } + ~Response() override = default; + unsigned int requestId = 0; + bool result = false; + std::string command; + std::optional error; +}; + +struct Event: public ProtocolMessage { + explicit Event(const std::string &e) : event(e) + { + type = ProtocolMessage::Type::EVENT; + } + ~Event() override = default; + std::string event; + bool result = true; +}; + +class ProtocolUtil { +public: + ProtocolUtil() = default; + ~ProtocolUtil() = default; + + void Register(); + void UnRegister(); + + std::unique_ptr FromJson(const json_t& requestJson, std::string& error); + std::optional ToJson(const Response& response, std::string& error); + std::optional ToJson(const Event& event, std::string& error); + + // set base info + // request + static bool SetRequestBaseInfo(Request &request, const json_t& json); + // response + static void SetResponseJsonBaseInfo(const Response& response, document_t & json); + // event + static void SetEventJsonBaseInfo(const Event& event, document_t& json); + +protected: + std::mutex mutex; + using JsonToRequestFunc = std::function(const json_t&, std::string&)>; + using ResponseToJsonFunc = std::function(const Response&)>; + using EventToJsonFunc = std::function(const Event&)>; + std::map jsonToReqFactory; + std::map resToJsonFactory; + std::map eventToJsonFactory; +private: + virtual void RegisterJsonToRequestFuncs() = 0; + virtual void RegisterResponseToJsonFuncs() = 0; + virtual void RegisterEventToJsonFuncs() = 0; + // request + static bool IsRequest(const json_t &jsonRequest); + static std::string Command(const json_t &jsonRequest); + std::optional GetJsonToRequestFunc(const std::string& command); + // response + std::optional GetResponseToJsonFunc(const std::string& command); + // event + std::optional GetEventToJsonFunc(const std::string &event); +}; +} // namespace Protocol +} // namespace Dic +#endif // MSINSIGHT_PROTOCOL_UTIL_H + diff --git a/plugins/mindstudio-insight-plugins/plugin_core/include/WsSender.h b/plugins/mindstudio-insight-plugins/plugin_core/include/WsSender.h new file mode 100644 index 0000000000000000000000000000000000000000..b2ba6892230bf144f5eedee6875843046e11ee05 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/include/WsSender.h @@ -0,0 +1,11 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef MSINSIGHT_WSSENDER_H +#define MSINSIGHT_WSSENDER_H +#include "ProtocolUtil.h" +namespace Dic{ +void SendEvent(std::unique_ptr eventPtr); +void SendResponse(std::unique_ptr responsePtr, bool result, const std::string &errorMsg = "", const int errorCode = UNKNOW_ERROR); +} +#endif //MSINSIGHT_WSSENDER_H diff --git a/plugins/mindstudio-insight-plugins/plugin_core/lib/libmsinsight.so b/plugins/mindstudio-insight-plugins/plugin_core/lib/libmsinsight.so new file mode 100644 index 0000000000000000000000000000000000000000..f1933fd613c2359d5cc46c2f8e477f397afb323d Binary files /dev/null and b/plugins/mindstudio-insight-plugins/plugin_core/lib/libmsinsight.so differ diff --git a/plugins/mindstudio-insight-plugins/plugin_core/src/Logger.cpp b/plugins/mindstudio-insight-plugins/plugin_core/src/Logger.cpp new file mode 100644 index 0000000000000000000000000000000000000000..135a96ffb2ec60edb3cf772353d356f319e2bd92 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/src/Logger.cpp @@ -0,0 +1,31 @@ +/* +* Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include +#include +#include +#include "Logger.h" + +namespace Insight { +Logger &Logger::GetLogger() +{ + std::ios::sync_with_stdio(false); // improve the std::cout performance + static Logger errorLogger; + return errorLogger; +} + +Logger &Logger::start(std::string_view rank) +{ + const time_t current_time = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + std::tm local_time{}; +#ifdef _WIN32 + localtime_s(&local_time, ¤t_time); +#else + localtime_r(¤t_time, &local_time); +#endif + std::ostringstream oss; + oss << std::put_time(&local_time, "%Y-%m-%d %H-%M-%S"); + *this << std::endl << oss.str() << "|[" << rank << "] "; + return *this; +} +} diff --git a/plugins/mindstudio-insight-plugins/plugin_core/src/PluginsManager.cpp b/plugins/mindstudio-insight-plugins/plugin_core/src/PluginsManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cef7e2359ba25fd1047c6c4b1b0e72968a28133e --- /dev/null +++ b/plugins/mindstudio-insight-plugins/plugin_core/src/PluginsManager.cpp @@ -0,0 +1,55 @@ +#include "PluginsManager.h" +#include + +using namespace Dic::Core; +#ifdef _WIN32 +#include +namespace fs = std::filesystem; +const static std::string EXT = ".dll"; +#else + +#include +#include + +namespace fs = std::experimental::filesystem; +const static std::string EXT = ".so"; +#endif + +PluginsManager &PluginsManager::Instance() { + static PluginsManager instance; + return instance; +} + +bool PluginsManager::RegisterPlugin(std::unique_ptr plugin) { + pluginsMap_.emplace(plugin->GetPluginName(), std::move(plugin)); + return true; +} + +void PluginsManager::LoadPlugins() { + auto pluginsDir = PLUGINS_DIR; + if (!fs::exists(pluginsDir)) { + return; + } + for (auto &dir: fs::directory_iterator(pluginsDir)) { + if (!fs::is_directory(dir)) { + continue; + } + for (auto &file: fs::directory_iterator(dir)) { + if (!fs::is_directory(file) && file.path().extension().string() == EXT) { +#ifdef _WIN32 +#else + dlopen(file.path().string().c_str(), RTLD_LAZY); +#endif + } + } + } + +} + +std::map> &PluginsManager::GetAllPlugins() { + return pluginsMap_; +} + +PluginRegister::PluginRegister(std::unique_ptr plugin) { + PluginsManager::Instance().RegisterPlugin(std::move(plugin)); +} diff --git a/plugins/mindstudio-insight-plugins/proto/CMakeLists.txt b/plugins/mindstudio-insight-plugins/proto/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0644e2dbb4155f40b4331153fbfed9ad3acdaf17 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/CMakeLists.txt @@ -0,0 +1,6 @@ +project(PROTO_GENERATE) +message(STATUS ${PROJECT_ROOT_DIR}) +file(GLOB_RECURSE PROTO_FILE "${PROJECT_ROOT_DIR}/proto/*.proto") +message(.) +ms_protobuf_generate(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR} PROTO_SRC PROTO_H ${PROTO_FILE}) + diff --git a/plugins/mindstudio-insight-plugins/proto/event.pb.cc b/plugins/mindstudio-insight-plugins/proto/event.pb.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4be6162980b68fab426c900f247b6f01455a479 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/event.pb.cc @@ -0,0 +1,3065 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: event.proto + +#include "event.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +extern PROTOBUF_INTERNAL_EXPORT_event_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<5> scc_info_Event_event_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_event_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_LogMessage_event_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_event_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_RequestedExitCode_event_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_event_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_SessionLog_event_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_event_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_SourceMetadata_event_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_summary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_Summary_summary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_event_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TaggedRunMetadata_event_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_event_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_WatchdogConfig_event_2eproto; +namespace tensorboard { +class EventDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr file_version_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr graph_def_; + const ::tensorboard::Summary* summary_; + const ::tensorboard::LogMessage* log_message_; + const ::tensorboard::SessionLog* session_log_; + const ::tensorboard::TaggedRunMetadata* tagged_run_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr meta_graph_def_; +} _Event_default_instance_; +class SourceMetadataDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _SourceMetadata_default_instance_; +class LogMessageDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _LogMessage_default_instance_; +class SessionLogDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _SessionLog_default_instance_; +class TaggedRunMetadataDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TaggedRunMetadata_default_instance_; +class WatchdogConfigDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _WatchdogConfig_default_instance_; +class RequestedExitCodeDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _RequestedExitCode_default_instance_; +class WorkerHeartbeatRequestDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _WorkerHeartbeatRequest_default_instance_; +class WorkerHeartbeatResponseDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _WorkerHeartbeatResponse_default_instance_; +} // namespace tensorboard +static void InitDefaultsscc_info_Event_event_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_Event_default_instance_; + new (ptr) ::tensorboard::Event(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::Event::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<5> scc_info_Event_event_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 5, 0, InitDefaultsscc_info_Event_event_2eproto}, { + &scc_info_Summary_summary_2eproto.base, + &scc_info_LogMessage_event_2eproto.base, + &scc_info_SessionLog_event_2eproto.base, + &scc_info_TaggedRunMetadata_event_2eproto.base, + &scc_info_SourceMetadata_event_2eproto.base,}}; + +static void InitDefaultsscc_info_LogMessage_event_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_LogMessage_default_instance_; + new (ptr) ::tensorboard::LogMessage(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::LogMessage::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_LogMessage_event_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_LogMessage_event_2eproto}, {}}; + +static void InitDefaultsscc_info_RequestedExitCode_event_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_RequestedExitCode_default_instance_; + new (ptr) ::tensorboard::RequestedExitCode(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::RequestedExitCode::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_RequestedExitCode_event_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_RequestedExitCode_event_2eproto}, {}}; + +static void InitDefaultsscc_info_SessionLog_event_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_SessionLog_default_instance_; + new (ptr) ::tensorboard::SessionLog(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::SessionLog::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_SessionLog_event_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_SessionLog_event_2eproto}, {}}; + +static void InitDefaultsscc_info_SourceMetadata_event_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_SourceMetadata_default_instance_; + new (ptr) ::tensorboard::SourceMetadata(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::SourceMetadata::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_SourceMetadata_event_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_SourceMetadata_event_2eproto}, {}}; + +static void InitDefaultsscc_info_TaggedRunMetadata_event_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_TaggedRunMetadata_default_instance_; + new (ptr) ::tensorboard::TaggedRunMetadata(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::TaggedRunMetadata::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TaggedRunMetadata_event_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_TaggedRunMetadata_event_2eproto}, {}}; + +static void InitDefaultsscc_info_WatchdogConfig_event_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_WatchdogConfig_default_instance_; + new (ptr) ::tensorboard::WatchdogConfig(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::WatchdogConfig::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_WatchdogConfig_event_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_WatchdogConfig_event_2eproto}, {}}; + +static void InitDefaultsscc_info_WorkerHeartbeatRequest_event_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_WorkerHeartbeatRequest_default_instance_; + new (ptr) ::tensorboard::WorkerHeartbeatRequest(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::WorkerHeartbeatRequest::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<2> scc_info_WorkerHeartbeatRequest_event_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 2, 0, InitDefaultsscc_info_WorkerHeartbeatRequest_event_2eproto}, { + &scc_info_WatchdogConfig_event_2eproto.base, + &scc_info_RequestedExitCode_event_2eproto.base,}}; + +static void InitDefaultsscc_info_WorkerHeartbeatResponse_event_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_WorkerHeartbeatResponse_default_instance_; + new (ptr) ::tensorboard::WorkerHeartbeatResponse(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::WorkerHeartbeatResponse::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_WorkerHeartbeatResponse_event_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_WorkerHeartbeatResponse_event_2eproto}, { + &scc_info_Event_event_2eproto.base,}}; + +static ::PROTOBUF_NAMESPACE_ID::Metadata file_level_metadata_event_2eproto[9]; +static const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* file_level_enum_descriptors_event_2eproto[4]; +static constexpr ::PROTOBUF_NAMESPACE_ID::ServiceDescriptor const** file_level_service_descriptors_event_2eproto = nullptr; + +const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_event_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Event, _internal_metadata_), + ~0u, // no _extensions_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Event, _oneof_case_[0]), + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Event, wall_time_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Event, step_), + offsetof(::tensorboard::EventDefaultTypeInternal, file_version_), + offsetof(::tensorboard::EventDefaultTypeInternal, graph_def_), + offsetof(::tensorboard::EventDefaultTypeInternal, summary_), + offsetof(::tensorboard::EventDefaultTypeInternal, log_message_), + offsetof(::tensorboard::EventDefaultTypeInternal, session_log_), + offsetof(::tensorboard::EventDefaultTypeInternal, tagged_run_metadata_), + offsetof(::tensorboard::EventDefaultTypeInternal, meta_graph_def_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Event, source_metadata_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Event, what_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SourceMetadata, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SourceMetadata, writer_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::LogMessage, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::LogMessage, level_), + PROTOBUF_FIELD_OFFSET(::tensorboard::LogMessage, message_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SessionLog, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SessionLog, status_), + PROTOBUF_FIELD_OFFSET(::tensorboard::SessionLog, checkpoint_path_), + PROTOBUF_FIELD_OFFSET(::tensorboard::SessionLog, msg_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::TaggedRunMetadata, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::TaggedRunMetadata, tag_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TaggedRunMetadata, run_metadata_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::WatchdogConfig, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::WatchdogConfig, timeout_ms_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::RequestedExitCode, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::RequestedExitCode, exit_code_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::WorkerHeartbeatRequest, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::WorkerHeartbeatRequest, shutdown_mode_), + PROTOBUF_FIELD_OFFSET(::tensorboard::WorkerHeartbeatRequest, watchdog_config_), + PROTOBUF_FIELD_OFFSET(::tensorboard::WorkerHeartbeatRequest, exit_code_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::WorkerHeartbeatResponse, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::WorkerHeartbeatResponse, health_status_), + PROTOBUF_FIELD_OFFSET(::tensorboard::WorkerHeartbeatResponse, worker_log_), + PROTOBUF_FIELD_OFFSET(::tensorboard::WorkerHeartbeatResponse, hostname_), +}; +static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, -1, sizeof(::tensorboard::Event)}, + { 16, -1, sizeof(::tensorboard::SourceMetadata)}, + { 22, -1, sizeof(::tensorboard::LogMessage)}, + { 29, -1, sizeof(::tensorboard::SessionLog)}, + { 37, -1, sizeof(::tensorboard::TaggedRunMetadata)}, + { 44, -1, sizeof(::tensorboard::WatchdogConfig)}, + { 50, -1, sizeof(::tensorboard::RequestedExitCode)}, + { 56, -1, sizeof(::tensorboard::WorkerHeartbeatRequest)}, + { 64, -1, sizeof(::tensorboard::WorkerHeartbeatResponse)}, +}; + +static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = { + reinterpret_cast(&::tensorboard::_Event_default_instance_), + reinterpret_cast(&::tensorboard::_SourceMetadata_default_instance_), + reinterpret_cast(&::tensorboard::_LogMessage_default_instance_), + reinterpret_cast(&::tensorboard::_SessionLog_default_instance_), + reinterpret_cast(&::tensorboard::_TaggedRunMetadata_default_instance_), + reinterpret_cast(&::tensorboard::_WatchdogConfig_default_instance_), + reinterpret_cast(&::tensorboard::_RequestedExitCode_default_instance_), + reinterpret_cast(&::tensorboard::_WorkerHeartbeatRequest_default_instance_), + reinterpret_cast(&::tensorboard::_WorkerHeartbeatResponse_default_instance_), +}; + +const char descriptor_table_protodef_event_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = + "\n\013event.proto\022\013tensorboard\032\rsummary.prot" + "o\"\371\002\n\005Event\022\021\n\twall_time\030\001 \001(\001\022\014\n\004step\030\002" + " \001(\003\022\026\n\014file_version\030\003 \001(\tH\000\022\023\n\tgraph_de" + "f\030\004 \001(\014H\000\022\'\n\007summary\030\005 \001(\0132\024.tensorboard" + ".SummaryH\000\0222\n\013log_message\030\006 \001(\0132\027.tensor" + "board.LogMessageB\002\030\001H\000\022.\n\013session_log\030\007 " + "\001(\0132\027.tensorboard.SessionLogH\000\022=\n\023tagged" + "_run_metadata\030\010 \001(\0132\036.tensorboard.Tagged" + "RunMetadataH\000\022\030\n\016meta_graph_def\030\t \001(\014H\000\022" + "4\n\017source_metadata\030\n \001(\0132\033.tensorboard.S" + "ourceMetadataB\006\n\004what\" \n\016SourceMetadata\022" + "\016\n\006writer\030\001 \001(\t\"\242\001\n\nLogMessage\022,\n\005level\030" + "\001 \001(\0162\035.tensorboard.LogMessage.Level\022\017\n\007" + "message\030\002 \001(\t\"Q\n\005Level\022\013\n\007UNKNOWN\020\000\022\r\n\tD" + "EBUGGING\020\n\022\010\n\004INFO\020\024\022\010\n\004WARN\020\036\022\t\n\005ERROR\020" + "(\022\t\n\005FATAL\0202\032\002\030\001:\002\030\001\"\267\001\n\nSessionLog\0225\n\006s" + "tatus\030\001 \001(\0162%.tensorboard.SessionLog.Ses" + "sionStatus\022\027\n\017checkpoint_path\030\002 \001(\t\022\013\n\003m" + "sg\030\003 \001(\t\"L\n\rSessionStatus\022\026\n\022STATUS_UNSP" + "ECIFIED\020\000\022\t\n\005START\020\001\022\010\n\004STOP\020\002\022\016\n\nCHECKP" + "OINT\020\003\"6\n\021TaggedRunMetadata\022\013\n\003tag\030\001 \001(\t" + "\022\024\n\014run_metadata\030\002 \001(\014\"$\n\016WatchdogConfig" + "\022\022\n\ntimeout_ms\030\001 \001(\003\"&\n\021RequestedExitCod" + "e\022\021\n\texit_code\030\001 \001(\005\"\271\001\n\026WorkerHeartbeat" + "Request\0226\n\rshutdown_mode\030\001 \001(\0162\037.tensorb" + "oard.WorkerShutdownMode\0224\n\017watchdog_conf" + "ig\030\002 \001(\0132\033.tensorboard.WatchdogConfig\0221\n" + "\texit_code\030\003 \001(\0132\036.tensorboard.Requested" + "ExitCode\"\205\001\n\027WorkerHeartbeatResponse\0220\n\r" + "health_status\030\001 \001(\0162\031.tensorboard.Worker" + "Health\022&\n\nworker_log\030\002 \003(\0132\022.tensorboard" + ".Event\022\020\n\010hostname\030\003 \001(\t*[\n\014WorkerHealth" + "\022\006\n\002OK\020\000\022\034\n\030RECEIVED_SHUTDOWN_SIGNAL\020\001\022\022" + "\n\016INTERNAL_ERROR\020\002\022\021\n\rSHUTTING_DOWN\020\003*k\n" + "\022WorkerShutdownMode\022\013\n\007DEFAULT\020\000\022\022\n\016NOT_" + "CONFIGURED\020\001\022\030\n\024WAIT_FOR_COORDINATOR\020\002\022\032" + "\n\026SHUTDOWN_AFTER_TIMEOUT\020\003Bp\n\023org.tensor" + "flow.utilB\013EventProtosP\001ZGgithub.com/ten" + "sorflow/tensorflow/tensorflow/go/core/ut" + "il/event_go_proto\370\001\001b\006proto3" + ; +static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_event_2eproto_deps[1] = { + &::descriptor_table_summary_2eproto, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_event_2eproto_sccs[9] = { + &scc_info_Event_event_2eproto.base, + &scc_info_LogMessage_event_2eproto.base, + &scc_info_RequestedExitCode_event_2eproto.base, + &scc_info_SessionLog_event_2eproto.base, + &scc_info_SourceMetadata_event_2eproto.base, + &scc_info_TaggedRunMetadata_event_2eproto.base, + &scc_info_WatchdogConfig_event_2eproto.base, + &scc_info_WorkerHeartbeatRequest_event_2eproto.base, + &scc_info_WorkerHeartbeatResponse_event_2eproto.base, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_event_2eproto_once; +const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_event_2eproto = { + false, false, descriptor_table_protodef_event_2eproto, "event.proto", 1588, + &descriptor_table_event_2eproto_once, descriptor_table_event_2eproto_sccs, descriptor_table_event_2eproto_deps, 9, 1, + schemas, file_default_instances, TableStruct_event_2eproto::offsets, + file_level_metadata_event_2eproto, 9, file_level_enum_descriptors_event_2eproto, file_level_service_descriptors_event_2eproto, +}; + +// Force running AddDescriptors() at dynamic initialization time. +static bool dynamic_init_dummy_event_2eproto = (static_cast(::PROTOBUF_NAMESPACE_ID::internal::AddDescriptors(&descriptor_table_event_2eproto)), true); +namespace tensorboard { +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* LogMessage_Level_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_event_2eproto); + return file_level_enum_descriptors_event_2eproto[0]; +} +bool LogMessage_Level_IsValid(int value) { + switch (value) { + case 0: + case 10: + case 20: + case 30: + case 40: + case 50: + return true; + default: + return false; + } +} + +#if (__cplusplus < 201703) && (!defined(_MSC_VER) || _MSC_VER >= 1900) +constexpr LogMessage_Level LogMessage::UNKNOWN; +constexpr LogMessage_Level LogMessage::DEBUGGING; +constexpr LogMessage_Level LogMessage::INFO; +constexpr LogMessage_Level LogMessage::WARN; +constexpr LogMessage_Level LogMessage::ERROR; +constexpr LogMessage_Level LogMessage::FATAL; +constexpr LogMessage_Level LogMessage::Level_MIN; +constexpr LogMessage_Level LogMessage::Level_MAX; +constexpr int LogMessage::Level_ARRAYSIZE; +#endif // (__cplusplus < 201703) && (!defined(_MSC_VER) || _MSC_VER >= 1900) +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* SessionLog_SessionStatus_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_event_2eproto); + return file_level_enum_descriptors_event_2eproto[1]; +} +bool SessionLog_SessionStatus_IsValid(int value) { + switch (value) { + case 0: + case 1: + case 2: + case 3: + return true; + default: + return false; + } +} + +#if (__cplusplus < 201703) && (!defined(_MSC_VER) || _MSC_VER >= 1900) +constexpr SessionLog_SessionStatus SessionLog::STATUS_UNSPECIFIED; +constexpr SessionLog_SessionStatus SessionLog::START; +constexpr SessionLog_SessionStatus SessionLog::STOP; +constexpr SessionLog_SessionStatus SessionLog::CHECKPOINT; +constexpr SessionLog_SessionStatus SessionLog::SessionStatus_MIN; +constexpr SessionLog_SessionStatus SessionLog::SessionStatus_MAX; +constexpr int SessionLog::SessionStatus_ARRAYSIZE; +#endif // (__cplusplus < 201703) && (!defined(_MSC_VER) || _MSC_VER >= 1900) +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* WorkerHealth_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_event_2eproto); + return file_level_enum_descriptors_event_2eproto[2]; +} +bool WorkerHealth_IsValid(int value) { + switch (value) { + case 0: + case 1: + case 2: + case 3: + return true; + default: + return false; + } +} + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* WorkerShutdownMode_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_event_2eproto); + return file_level_enum_descriptors_event_2eproto[3]; +} +bool WorkerShutdownMode_IsValid(int value) { + switch (value) { + case 0: + case 1: + case 2: + case 3: + return true; + default: + return false; + } +} + + +// =================================================================== + +void Event::InitAsDefaultInstance() { + ::tensorboard::_Event_default_instance_.file_version_.UnsafeSetDefault( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::tensorboard::_Event_default_instance_.graph_def_.UnsafeSetDefault( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::tensorboard::_Event_default_instance_.summary_ = const_cast< ::tensorboard::Summary*>( + ::tensorboard::Summary::internal_default_instance()); + ::tensorboard::_Event_default_instance_.log_message_ = const_cast< ::tensorboard::LogMessage*>( + ::tensorboard::LogMessage::internal_default_instance()); + ::tensorboard::_Event_default_instance_.session_log_ = const_cast< ::tensorboard::SessionLog*>( + ::tensorboard::SessionLog::internal_default_instance()); + ::tensorboard::_Event_default_instance_.tagged_run_metadata_ = const_cast< ::tensorboard::TaggedRunMetadata*>( + ::tensorboard::TaggedRunMetadata::internal_default_instance()); + ::tensorboard::_Event_default_instance_.meta_graph_def_.UnsafeSetDefault( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::tensorboard::_Event_default_instance_._instance.get_mutable()->source_metadata_ = const_cast< ::tensorboard::SourceMetadata*>( + ::tensorboard::SourceMetadata::internal_default_instance()); +} +class Event::_Internal { + public: + static const ::tensorboard::Summary& summary(const Event* msg); + static const ::tensorboard::LogMessage& log_message(const Event* msg); + static const ::tensorboard::SessionLog& session_log(const Event* msg); + static const ::tensorboard::TaggedRunMetadata& tagged_run_metadata(const Event* msg); + static const ::tensorboard::SourceMetadata& source_metadata(const Event* msg); +}; + +const ::tensorboard::Summary& +Event::_Internal::summary(const Event* msg) { + return *msg->what_.summary_; +} +const ::tensorboard::LogMessage& +Event::_Internal::log_message(const Event* msg) { + return *msg->what_.log_message_; +} +const ::tensorboard::SessionLog& +Event::_Internal::session_log(const Event* msg) { + return *msg->what_.session_log_; +} +const ::tensorboard::TaggedRunMetadata& +Event::_Internal::tagged_run_metadata(const Event* msg) { + return *msg->what_.tagged_run_metadata_; +} +const ::tensorboard::SourceMetadata& +Event::_Internal::source_metadata(const Event* msg) { + return *msg->source_metadata_; +} +void Event::set_allocated_summary(::tensorboard::Summary* summary) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_what(); + if (summary) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(summary)->GetArena(); + if (message_arena != submessage_arena) { + summary = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, summary, submessage_arena); + } + set_has_summary(); + what_.summary_ = summary; + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Event.summary) +} +void Event::clear_summary() { + if (_internal_has_summary()) { + if (GetArena() == nullptr) { + delete what_.summary_; + } + clear_has_what(); + } +} +void Event::set_allocated_log_message(::tensorboard::LogMessage* log_message) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_what(); + if (log_message) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(log_message); + if (message_arena != submessage_arena) { + log_message = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, log_message, submessage_arena); + } + set_has_log_message(); + what_.log_message_ = log_message; + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Event.log_message) +} +void Event::set_allocated_session_log(::tensorboard::SessionLog* session_log) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_what(); + if (session_log) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(session_log); + if (message_arena != submessage_arena) { + session_log = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, session_log, submessage_arena); + } + set_has_session_log(); + what_.session_log_ = session_log; + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Event.session_log) +} +void Event::set_allocated_tagged_run_metadata(::tensorboard::TaggedRunMetadata* tagged_run_metadata) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_what(); + if (tagged_run_metadata) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(tagged_run_metadata); + if (message_arena != submessage_arena) { + tagged_run_metadata = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, tagged_run_metadata, submessage_arena); + } + set_has_tagged_run_metadata(); + what_.tagged_run_metadata_ = tagged_run_metadata; + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Event.tagged_run_metadata) +} +Event::Event(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.Event) +} +Event::Event(const Event& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + if (from._internal_has_source_metadata()) { + source_metadata_ = new ::tensorboard::SourceMetadata(*from.source_metadata_); + } else { + source_metadata_ = nullptr; + } + ::memcpy(&wall_time_, &from.wall_time_, + static_cast(reinterpret_cast(&step_) - + reinterpret_cast(&wall_time_)) + sizeof(step_)); + clear_has_what(); + switch (from.what_case()) { + case kFileVersion: { + _internal_set_file_version(from._internal_file_version()); + break; + } + case kGraphDef: { + _internal_set_graph_def(from._internal_graph_def()); + break; + } + case kSummary: { + _internal_mutable_summary()->::tensorboard::Summary::MergeFrom(from._internal_summary()); + break; + } + case kLogMessage: { + _internal_mutable_log_message()->::tensorboard::LogMessage::MergeFrom(from._internal_log_message()); + break; + } + case kSessionLog: { + _internal_mutable_session_log()->::tensorboard::SessionLog::MergeFrom(from._internal_session_log()); + break; + } + case kTaggedRunMetadata: { + _internal_mutable_tagged_run_metadata()->::tensorboard::TaggedRunMetadata::MergeFrom(from._internal_tagged_run_metadata()); + break; + } + case kMetaGraphDef: { + _internal_set_meta_graph_def(from._internal_meta_graph_def()); + break; + } + case WHAT_NOT_SET: { + break; + } + } + // @@protoc_insertion_point(copy_constructor:tensorboard.Event) +} + +void Event::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Event_event_2eproto.base); + ::memset(&source_metadata_, 0, static_cast( + reinterpret_cast(&step_) - + reinterpret_cast(&source_metadata_)) + sizeof(step_)); + clear_has_what(); +} + +Event::~Event() { + // @@protoc_insertion_point(destructor:tensorboard.Event) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Event::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (this != internal_default_instance()) delete source_metadata_; + if (has_what()) { + clear_what(); + } +} + +void Event::ArenaDtor(void* object) { + Event* _this = reinterpret_cast< Event* >(object); + (void)_this; +} +void Event::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Event::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Event& Event::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Event_event_2eproto.base); + return *internal_default_instance(); +} + + +void Event::clear_what() { +// @@protoc_insertion_point(one_of_clear_start:tensorboard.Event) + switch (what_case()) { + case kFileVersion: { + what_.file_version_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + break; + } + case kGraphDef: { + what_.graph_def_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + break; + } + case kSummary: { + if (GetArena() == nullptr) { + delete what_.summary_; + } + break; + } + case kLogMessage: { + if (GetArena() == nullptr) { + delete what_.log_message_; + } + break; + } + case kSessionLog: { + if (GetArena() == nullptr) { + delete what_.session_log_; + } + break; + } + case kTaggedRunMetadata: { + if (GetArena() == nullptr) { + delete what_.tagged_run_metadata_; + } + break; + } + case kMetaGraphDef: { + what_.meta_graph_def_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + break; + } + case WHAT_NOT_SET: { + break; + } + } + _oneof_case_[0] = WHAT_NOT_SET; +} + + +void Event::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.Event) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + if (GetArena() == nullptr && source_metadata_ != nullptr) { + delete source_metadata_; + } + source_metadata_ = nullptr; + ::memset(&wall_time_, 0, static_cast( + reinterpret_cast(&step_) - + reinterpret_cast(&wall_time_)) + sizeof(step_)); + clear_what(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Event::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // double wall_time = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 9)) { + wall_time_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // int64 step = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + step_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string file_version = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_file_version(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.Event.file_version")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // bytes graph_def = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + auto str = _internal_mutable_graph_def(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.Summary summary = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr = ctx->ParseMessage(_internal_mutable_summary(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.LogMessage log_message = 6 [deprecated = true]; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr = ctx->ParseMessage(_internal_mutable_log_message(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.SessionLog session_log = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + ptr = ctx->ParseMessage(_internal_mutable_session_log(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.TaggedRunMetadata tagged_run_metadata = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + ptr = ctx->ParseMessage(_internal_mutable_tagged_run_metadata(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // bytes meta_graph_def = 9; + case 9: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 74)) { + auto str = _internal_mutable_meta_graph_def(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.SourceMetadata source_metadata = 10; + case 10: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 82)) { + ptr = ctx->ParseMessage(_internal_mutable_source_metadata(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Event::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.Event) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // double wall_time = 1; + if (!(this->wall_time() <= 0 && this->wall_time() >= 0)) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(1, this->_internal_wall_time(), target); + } + + // int64 step = 2; + if (this->step() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(2, this->_internal_step(), target); + } + + // string file_version = 3; + if (_internal_has_file_version()) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_file_version().data(), static_cast(this->_internal_file_version().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.Event.file_version"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_file_version(), target); + } + + // bytes graph_def = 4; + if (_internal_has_graph_def()) { + target = stream->WriteBytesMaybeAliased( + 4, this->_internal_graph_def(), target); + } + + // .tensorboard.Summary summary = 5; + if (_internal_has_summary()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 5, _Internal::summary(this), target, stream); + } + + // .tensorboard.LogMessage log_message = 6 [deprecated = true]; + if (_internal_has_log_message()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 6, _Internal::log_message(this), target, stream); + } + + // .tensorboard.SessionLog session_log = 7; + if (_internal_has_session_log()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 7, _Internal::session_log(this), target, stream); + } + + // .tensorboard.TaggedRunMetadata tagged_run_metadata = 8; + if (_internal_has_tagged_run_metadata()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 8, _Internal::tagged_run_metadata(this), target, stream); + } + + // bytes meta_graph_def = 9; + if (_internal_has_meta_graph_def()) { + target = stream->WriteBytesMaybeAliased( + 9, this->_internal_meta_graph_def(), target); + } + + // .tensorboard.SourceMetadata source_metadata = 10; + if (this->has_source_metadata()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 10, _Internal::source_metadata(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.Event) + return target; +} + +size_t Event::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.Event) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // .tensorboard.SourceMetadata source_metadata = 10; + if (this->has_source_metadata()) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *source_metadata_); + } + + // double wall_time = 1; + if (!(this->wall_time() <= 0 && this->wall_time() >= 0)) { + total_size += 1 + 8; + } + + // int64 step = 2; + if (this->step() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_step()); + } + + switch (what_case()) { + // string file_version = 3; + case kFileVersion: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_file_version()); + break; + } + // bytes graph_def = 4; + case kGraphDef: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_graph_def()); + break; + } + // .tensorboard.Summary summary = 5; + case kSummary: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *what_.summary_); + break; + } + // .tensorboard.LogMessage log_message = 6 [deprecated = true]; + case kLogMessage: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *what_.log_message_); + break; + } + // .tensorboard.SessionLog session_log = 7; + case kSessionLog: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *what_.session_log_); + break; + } + // .tensorboard.TaggedRunMetadata tagged_run_metadata = 8; + case kTaggedRunMetadata: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *what_.tagged_run_metadata_); + break; + } + // bytes meta_graph_def = 9; + case kMetaGraphDef: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_meta_graph_def()); + break; + } + case WHAT_NOT_SET: { + break; + } + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Event::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.Event) + GOOGLE_DCHECK_NE(&from, this); + const Event* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.Event) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.Event) + MergeFrom(*source); + } +} + +void Event::MergeFrom(const Event& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.Event) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.has_source_metadata()) { + _internal_mutable_source_metadata()->::tensorboard::SourceMetadata::MergeFrom(from._internal_source_metadata()); + } + if (!(from.wall_time() <= 0 && from.wall_time() >= 0)) { + _internal_set_wall_time(from._internal_wall_time()); + } + if (from.step() != 0) { + _internal_set_step(from._internal_step()); + } + switch (from.what_case()) { + case kFileVersion: { + _internal_set_file_version(from._internal_file_version()); + break; + } + case kGraphDef: { + _internal_set_graph_def(from._internal_graph_def()); + break; + } + case kSummary: { + _internal_mutable_summary()->::tensorboard::Summary::MergeFrom(from._internal_summary()); + break; + } + case kLogMessage: { + _internal_mutable_log_message()->::tensorboard::LogMessage::MergeFrom(from._internal_log_message()); + break; + } + case kSessionLog: { + _internal_mutable_session_log()->::tensorboard::SessionLog::MergeFrom(from._internal_session_log()); + break; + } + case kTaggedRunMetadata: { + _internal_mutable_tagged_run_metadata()->::tensorboard::TaggedRunMetadata::MergeFrom(from._internal_tagged_run_metadata()); + break; + } + case kMetaGraphDef: { + _internal_set_meta_graph_def(from._internal_meta_graph_def()); + break; + } + case WHAT_NOT_SET: { + break; + } + } +} + +void Event::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.Event) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Event::CopyFrom(const Event& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.Event) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Event::IsInitialized() const { + return true; +} + +void Event::InternalSwap(Event* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(Event, step_) + + sizeof(Event::step_) + - PROTOBUF_FIELD_OFFSET(Event, source_metadata_)>( + reinterpret_cast(&source_metadata_), + reinterpret_cast(&other->source_metadata_)); + swap(what_, other->what_); + swap(_oneof_case_[0], other->_oneof_case_[0]); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Event::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void SourceMetadata::InitAsDefaultInstance() { +} +class SourceMetadata::_Internal { + public: +}; + +SourceMetadata::SourceMetadata(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.SourceMetadata) +} +SourceMetadata::SourceMetadata(const SourceMetadata& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + writer_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_writer().empty()) { + writer_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_writer(), + GetArena()); + } + // @@protoc_insertion_point(copy_constructor:tensorboard.SourceMetadata) +} + +void SourceMetadata::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_SourceMetadata_event_2eproto.base); + writer_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +SourceMetadata::~SourceMetadata() { + // @@protoc_insertion_point(destructor:tensorboard.SourceMetadata) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void SourceMetadata::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + writer_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void SourceMetadata::ArenaDtor(void* object) { + SourceMetadata* _this = reinterpret_cast< SourceMetadata* >(object); + (void)_this; +} +void SourceMetadata::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void SourceMetadata::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const SourceMetadata& SourceMetadata::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_SourceMetadata_event_2eproto.base); + return *internal_default_instance(); +} + + +void SourceMetadata::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.SourceMetadata) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + writer_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* SourceMetadata::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // string writer = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_writer(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.SourceMetadata.writer")); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* SourceMetadata::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.SourceMetadata) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // string writer = 1; + if (this->writer().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_writer().data(), static_cast(this->_internal_writer().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.SourceMetadata.writer"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_writer(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.SourceMetadata) + return target; +} + +size_t SourceMetadata::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.SourceMetadata) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // string writer = 1; + if (this->writer().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_writer()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void SourceMetadata::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.SourceMetadata) + GOOGLE_DCHECK_NE(&from, this); + const SourceMetadata* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.SourceMetadata) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.SourceMetadata) + MergeFrom(*source); + } +} + +void SourceMetadata::MergeFrom(const SourceMetadata& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.SourceMetadata) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.writer().size() > 0) { + _internal_set_writer(from._internal_writer()); + } +} + +void SourceMetadata::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.SourceMetadata) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void SourceMetadata::CopyFrom(const SourceMetadata& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.SourceMetadata) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool SourceMetadata::IsInitialized() const { + return true; +} + +void SourceMetadata::InternalSwap(SourceMetadata* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + writer_.Swap(&other->writer_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata SourceMetadata::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void LogMessage::InitAsDefaultInstance() { +} +class LogMessage::_Internal { + public: +}; + +LogMessage::LogMessage(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.LogMessage) +} +LogMessage::LogMessage(const LogMessage& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + message_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_message().empty()) { + message_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_message(), + GetArena()); + } + level_ = from.level_; + // @@protoc_insertion_point(copy_constructor:tensorboard.LogMessage) +} + +void LogMessage::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_LogMessage_event_2eproto.base); + message_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + level_ = 0; +} + +LogMessage::~LogMessage() { + // @@protoc_insertion_point(destructor:tensorboard.LogMessage) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void LogMessage::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + message_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void LogMessage::ArenaDtor(void* object) { + LogMessage* _this = reinterpret_cast< LogMessage* >(object); + (void)_this; +} +void LogMessage::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void LogMessage::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const LogMessage& LogMessage::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_LogMessage_event_2eproto.base); + return *internal_default_instance(); +} + + +void LogMessage::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.LogMessage) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + message_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + level_ = 0; + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* LogMessage::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // .tensorboard.LogMessage.Level level = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + _internal_set_level(static_cast<::tensorboard::LogMessage_Level>(val)); + } else goto handle_unusual; + continue; + // string message = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_message(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.LogMessage.message")); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* LogMessage::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.LogMessage) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // .tensorboard.LogMessage.Level level = 1; + if (this->level() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_level(), target); + } + + // string message = 2; + if (this->message().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_message().data(), static_cast(this->_internal_message().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.LogMessage.message"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_message(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.LogMessage) + return target; +} + +size_t LogMessage::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.LogMessage) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // string message = 2; + if (this->message().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_message()); + } + + // .tensorboard.LogMessage.Level level = 1; + if (this->level() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_level()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void LogMessage::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.LogMessage) + GOOGLE_DCHECK_NE(&from, this); + const LogMessage* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.LogMessage) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.LogMessage) + MergeFrom(*source); + } +} + +void LogMessage::MergeFrom(const LogMessage& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.LogMessage) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.message().size() > 0) { + _internal_set_message(from._internal_message()); + } + if (from.level() != 0) { + _internal_set_level(from._internal_level()); + } +} + +void LogMessage::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.LogMessage) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void LogMessage::CopyFrom(const LogMessage& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.LogMessage) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool LogMessage::IsInitialized() const { + return true; +} + +void LogMessage::InternalSwap(LogMessage* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + message_.Swap(&other->message_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(level_, other->level_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata LogMessage::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void SessionLog::InitAsDefaultInstance() { +} +class SessionLog::_Internal { + public: +}; + +SessionLog::SessionLog(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.SessionLog) +} +SessionLog::SessionLog(const SessionLog& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + checkpoint_path_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_checkpoint_path().empty()) { + checkpoint_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_checkpoint_path(), + GetArena()); + } + msg_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_msg().empty()) { + msg_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_msg(), + GetArena()); + } + status_ = from.status_; + // @@protoc_insertion_point(copy_constructor:tensorboard.SessionLog) +} + +void SessionLog::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_SessionLog_event_2eproto.base); + checkpoint_path_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + msg_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + status_ = 0; +} + +SessionLog::~SessionLog() { + // @@protoc_insertion_point(destructor:tensorboard.SessionLog) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void SessionLog::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + checkpoint_path_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + msg_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void SessionLog::ArenaDtor(void* object) { + SessionLog* _this = reinterpret_cast< SessionLog* >(object); + (void)_this; +} +void SessionLog::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void SessionLog::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const SessionLog& SessionLog::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_SessionLog_event_2eproto.base); + return *internal_default_instance(); +} + + +void SessionLog::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.SessionLog) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + checkpoint_path_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + msg_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + status_ = 0; + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* SessionLog::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // .tensorboard.SessionLog.SessionStatus status = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + _internal_set_status(static_cast<::tensorboard::SessionLog_SessionStatus>(val)); + } else goto handle_unusual; + continue; + // string checkpoint_path = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_checkpoint_path(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.SessionLog.checkpoint_path")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string msg = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_msg(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.SessionLog.msg")); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* SessionLog::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.SessionLog) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // .tensorboard.SessionLog.SessionStatus status = 1; + if (this->status() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_status(), target); + } + + // string checkpoint_path = 2; + if (this->checkpoint_path().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_checkpoint_path().data(), static_cast(this->_internal_checkpoint_path().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.SessionLog.checkpoint_path"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_checkpoint_path(), target); + } + + // string msg = 3; + if (this->msg().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_msg().data(), static_cast(this->_internal_msg().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.SessionLog.msg"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_msg(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.SessionLog) + return target; +} + +size_t SessionLog::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.SessionLog) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // string checkpoint_path = 2; + if (this->checkpoint_path().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_checkpoint_path()); + } + + // string msg = 3; + if (this->msg().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_msg()); + } + + // .tensorboard.SessionLog.SessionStatus status = 1; + if (this->status() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_status()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void SessionLog::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.SessionLog) + GOOGLE_DCHECK_NE(&from, this); + const SessionLog* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.SessionLog) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.SessionLog) + MergeFrom(*source); + } +} + +void SessionLog::MergeFrom(const SessionLog& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.SessionLog) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.checkpoint_path().size() > 0) { + _internal_set_checkpoint_path(from._internal_checkpoint_path()); + } + if (from.msg().size() > 0) { + _internal_set_msg(from._internal_msg()); + } + if (from.status() != 0) { + _internal_set_status(from._internal_status()); + } +} + +void SessionLog::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.SessionLog) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void SessionLog::CopyFrom(const SessionLog& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.SessionLog) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool SessionLog::IsInitialized() const { + return true; +} + +void SessionLog::InternalSwap(SessionLog* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + checkpoint_path_.Swap(&other->checkpoint_path_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + msg_.Swap(&other->msg_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(status_, other->status_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata SessionLog::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TaggedRunMetadata::InitAsDefaultInstance() { +} +class TaggedRunMetadata::_Internal { + public: +}; + +TaggedRunMetadata::TaggedRunMetadata(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.TaggedRunMetadata) +} +TaggedRunMetadata::TaggedRunMetadata(const TaggedRunMetadata& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + tag_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_tag().empty()) { + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_tag(), + GetArena()); + } + run_metadata_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_run_metadata().empty()) { + run_metadata_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_run_metadata(), + GetArena()); + } + // @@protoc_insertion_point(copy_constructor:tensorboard.TaggedRunMetadata) +} + +void TaggedRunMetadata::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TaggedRunMetadata_event_2eproto.base); + tag_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + run_metadata_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +TaggedRunMetadata::~TaggedRunMetadata() { + // @@protoc_insertion_point(destructor:tensorboard.TaggedRunMetadata) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TaggedRunMetadata::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + tag_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + run_metadata_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void TaggedRunMetadata::ArenaDtor(void* object) { + TaggedRunMetadata* _this = reinterpret_cast< TaggedRunMetadata* >(object); + (void)_this; +} +void TaggedRunMetadata::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TaggedRunMetadata::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TaggedRunMetadata& TaggedRunMetadata::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TaggedRunMetadata_event_2eproto.base); + return *internal_default_instance(); +} + + +void TaggedRunMetadata::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.TaggedRunMetadata) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + tag_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + run_metadata_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TaggedRunMetadata::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // string tag = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_tag(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.TaggedRunMetadata.tag")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // bytes run_metadata = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_run_metadata(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TaggedRunMetadata::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.TaggedRunMetadata) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // string tag = 1; + if (this->tag().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_tag().data(), static_cast(this->_internal_tag().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.TaggedRunMetadata.tag"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_tag(), target); + } + + // bytes run_metadata = 2; + if (this->run_metadata().size() > 0) { + target = stream->WriteBytesMaybeAliased( + 2, this->_internal_run_metadata(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.TaggedRunMetadata) + return target; +} + +size_t TaggedRunMetadata::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.TaggedRunMetadata) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // string tag = 1; + if (this->tag().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_tag()); + } + + // bytes run_metadata = 2; + if (this->run_metadata().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_run_metadata()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TaggedRunMetadata::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.TaggedRunMetadata) + GOOGLE_DCHECK_NE(&from, this); + const TaggedRunMetadata* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.TaggedRunMetadata) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.TaggedRunMetadata) + MergeFrom(*source); + } +} + +void TaggedRunMetadata::MergeFrom(const TaggedRunMetadata& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.TaggedRunMetadata) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.tag().size() > 0) { + _internal_set_tag(from._internal_tag()); + } + if (from.run_metadata().size() > 0) { + _internal_set_run_metadata(from._internal_run_metadata()); + } +} + +void TaggedRunMetadata::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.TaggedRunMetadata) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TaggedRunMetadata::CopyFrom(const TaggedRunMetadata& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.TaggedRunMetadata) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TaggedRunMetadata::IsInitialized() const { + return true; +} + +void TaggedRunMetadata::InternalSwap(TaggedRunMetadata* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + tag_.Swap(&other->tag_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + run_metadata_.Swap(&other->run_metadata_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TaggedRunMetadata::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void WatchdogConfig::InitAsDefaultInstance() { +} +class WatchdogConfig::_Internal { + public: +}; + +WatchdogConfig::WatchdogConfig(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.WatchdogConfig) +} +WatchdogConfig::WatchdogConfig(const WatchdogConfig& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + timeout_ms_ = from.timeout_ms_; + // @@protoc_insertion_point(copy_constructor:tensorboard.WatchdogConfig) +} + +void WatchdogConfig::SharedCtor() { + timeout_ms_ = PROTOBUF_LONGLONG(0); +} + +WatchdogConfig::~WatchdogConfig() { + // @@protoc_insertion_point(destructor:tensorboard.WatchdogConfig) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void WatchdogConfig::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void WatchdogConfig::ArenaDtor(void* object) { + WatchdogConfig* _this = reinterpret_cast< WatchdogConfig* >(object); + (void)_this; +} +void WatchdogConfig::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void WatchdogConfig::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const WatchdogConfig& WatchdogConfig::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_WatchdogConfig_event_2eproto.base); + return *internal_default_instance(); +} + + +void WatchdogConfig::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.WatchdogConfig) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + timeout_ms_ = PROTOBUF_LONGLONG(0); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* WatchdogConfig::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // int64 timeout_ms = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + timeout_ms_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* WatchdogConfig::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.WatchdogConfig) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // int64 timeout_ms = 1; + if (this->timeout_ms() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(1, this->_internal_timeout_ms(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.WatchdogConfig) + return target; +} + +size_t WatchdogConfig::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.WatchdogConfig) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // int64 timeout_ms = 1; + if (this->timeout_ms() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_timeout_ms()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void WatchdogConfig::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.WatchdogConfig) + GOOGLE_DCHECK_NE(&from, this); + const WatchdogConfig* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.WatchdogConfig) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.WatchdogConfig) + MergeFrom(*source); + } +} + +void WatchdogConfig::MergeFrom(const WatchdogConfig& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.WatchdogConfig) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.timeout_ms() != 0) { + _internal_set_timeout_ms(from._internal_timeout_ms()); + } +} + +void WatchdogConfig::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.WatchdogConfig) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void WatchdogConfig::CopyFrom(const WatchdogConfig& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.WatchdogConfig) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool WatchdogConfig::IsInitialized() const { + return true; +} + +void WatchdogConfig::InternalSwap(WatchdogConfig* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(timeout_ms_, other->timeout_ms_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata WatchdogConfig::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void RequestedExitCode::InitAsDefaultInstance() { +} +class RequestedExitCode::_Internal { + public: +}; + +RequestedExitCode::RequestedExitCode(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.RequestedExitCode) +} +RequestedExitCode::RequestedExitCode(const RequestedExitCode& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + exit_code_ = from.exit_code_; + // @@protoc_insertion_point(copy_constructor:tensorboard.RequestedExitCode) +} + +void RequestedExitCode::SharedCtor() { + exit_code_ = 0; +} + +RequestedExitCode::~RequestedExitCode() { + // @@protoc_insertion_point(destructor:tensorboard.RequestedExitCode) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void RequestedExitCode::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void RequestedExitCode::ArenaDtor(void* object) { + RequestedExitCode* _this = reinterpret_cast< RequestedExitCode* >(object); + (void)_this; +} +void RequestedExitCode::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void RequestedExitCode::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const RequestedExitCode& RequestedExitCode::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_RequestedExitCode_event_2eproto.base); + return *internal_default_instance(); +} + + +void RequestedExitCode::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.RequestedExitCode) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + exit_code_ = 0; + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* RequestedExitCode::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // int32 exit_code = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + exit_code_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* RequestedExitCode::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.RequestedExitCode) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // int32 exit_code = 1; + if (this->exit_code() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(1, this->_internal_exit_code(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.RequestedExitCode) + return target; +} + +size_t RequestedExitCode::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.RequestedExitCode) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // int32 exit_code = 1; + if (this->exit_code() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_exit_code()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void RequestedExitCode::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.RequestedExitCode) + GOOGLE_DCHECK_NE(&from, this); + const RequestedExitCode* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.RequestedExitCode) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.RequestedExitCode) + MergeFrom(*source); + } +} + +void RequestedExitCode::MergeFrom(const RequestedExitCode& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.RequestedExitCode) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.exit_code() != 0) { + _internal_set_exit_code(from._internal_exit_code()); + } +} + +void RequestedExitCode::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.RequestedExitCode) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void RequestedExitCode::CopyFrom(const RequestedExitCode& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.RequestedExitCode) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool RequestedExitCode::IsInitialized() const { + return true; +} + +void RequestedExitCode::InternalSwap(RequestedExitCode* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(exit_code_, other->exit_code_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata RequestedExitCode::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void WorkerHeartbeatRequest::InitAsDefaultInstance() { + ::tensorboard::_WorkerHeartbeatRequest_default_instance_._instance.get_mutable()->watchdog_config_ = const_cast< ::tensorboard::WatchdogConfig*>( + ::tensorboard::WatchdogConfig::internal_default_instance()); + ::tensorboard::_WorkerHeartbeatRequest_default_instance_._instance.get_mutable()->exit_code_ = const_cast< ::tensorboard::RequestedExitCode*>( + ::tensorboard::RequestedExitCode::internal_default_instance()); +} +class WorkerHeartbeatRequest::_Internal { + public: + static const ::tensorboard::WatchdogConfig& watchdog_config(const WorkerHeartbeatRequest* msg); + static const ::tensorboard::RequestedExitCode& exit_code(const WorkerHeartbeatRequest* msg); +}; + +const ::tensorboard::WatchdogConfig& +WorkerHeartbeatRequest::_Internal::watchdog_config(const WorkerHeartbeatRequest* msg) { + return *msg->watchdog_config_; +} +const ::tensorboard::RequestedExitCode& +WorkerHeartbeatRequest::_Internal::exit_code(const WorkerHeartbeatRequest* msg) { + return *msg->exit_code_; +} +WorkerHeartbeatRequest::WorkerHeartbeatRequest(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.WorkerHeartbeatRequest) +} +WorkerHeartbeatRequest::WorkerHeartbeatRequest(const WorkerHeartbeatRequest& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + if (from._internal_has_watchdog_config()) { + watchdog_config_ = new ::tensorboard::WatchdogConfig(*from.watchdog_config_); + } else { + watchdog_config_ = nullptr; + } + if (from._internal_has_exit_code()) { + exit_code_ = new ::tensorboard::RequestedExitCode(*from.exit_code_); + } else { + exit_code_ = nullptr; + } + shutdown_mode_ = from.shutdown_mode_; + // @@protoc_insertion_point(copy_constructor:tensorboard.WorkerHeartbeatRequest) +} + +void WorkerHeartbeatRequest::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_WorkerHeartbeatRequest_event_2eproto.base); + ::memset(&watchdog_config_, 0, static_cast( + reinterpret_cast(&shutdown_mode_) - + reinterpret_cast(&watchdog_config_)) + sizeof(shutdown_mode_)); +} + +WorkerHeartbeatRequest::~WorkerHeartbeatRequest() { + // @@protoc_insertion_point(destructor:tensorboard.WorkerHeartbeatRequest) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void WorkerHeartbeatRequest::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (this != internal_default_instance()) delete watchdog_config_; + if (this != internal_default_instance()) delete exit_code_; +} + +void WorkerHeartbeatRequest::ArenaDtor(void* object) { + WorkerHeartbeatRequest* _this = reinterpret_cast< WorkerHeartbeatRequest* >(object); + (void)_this; +} +void WorkerHeartbeatRequest::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void WorkerHeartbeatRequest::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const WorkerHeartbeatRequest& WorkerHeartbeatRequest::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_WorkerHeartbeatRequest_event_2eproto.base); + return *internal_default_instance(); +} + + +void WorkerHeartbeatRequest::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.WorkerHeartbeatRequest) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + if (GetArena() == nullptr && watchdog_config_ != nullptr) { + delete watchdog_config_; + } + watchdog_config_ = nullptr; + if (GetArena() == nullptr && exit_code_ != nullptr) { + delete exit_code_; + } + exit_code_ = nullptr; + shutdown_mode_ = 0; + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* WorkerHeartbeatRequest::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // .tensorboard.WorkerShutdownMode shutdown_mode = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + _internal_set_shutdown_mode(static_cast<::tensorboard::WorkerShutdownMode>(val)); + } else goto handle_unusual; + continue; + // .tensorboard.WatchdogConfig watchdog_config = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_watchdog_config(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.RequestedExitCode exit_code = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr = ctx->ParseMessage(_internal_mutable_exit_code(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* WorkerHeartbeatRequest::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.WorkerHeartbeatRequest) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // .tensorboard.WorkerShutdownMode shutdown_mode = 1; + if (this->shutdown_mode() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_shutdown_mode(), target); + } + + // .tensorboard.WatchdogConfig watchdog_config = 2; + if (this->has_watchdog_config()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::watchdog_config(this), target, stream); + } + + // .tensorboard.RequestedExitCode exit_code = 3; + if (this->has_exit_code()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 3, _Internal::exit_code(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.WorkerHeartbeatRequest) + return target; +} + +size_t WorkerHeartbeatRequest::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.WorkerHeartbeatRequest) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // .tensorboard.WatchdogConfig watchdog_config = 2; + if (this->has_watchdog_config()) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *watchdog_config_); + } + + // .tensorboard.RequestedExitCode exit_code = 3; + if (this->has_exit_code()) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *exit_code_); + } + + // .tensorboard.WorkerShutdownMode shutdown_mode = 1; + if (this->shutdown_mode() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_shutdown_mode()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void WorkerHeartbeatRequest::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.WorkerHeartbeatRequest) + GOOGLE_DCHECK_NE(&from, this); + const WorkerHeartbeatRequest* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.WorkerHeartbeatRequest) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.WorkerHeartbeatRequest) + MergeFrom(*source); + } +} + +void WorkerHeartbeatRequest::MergeFrom(const WorkerHeartbeatRequest& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.WorkerHeartbeatRequest) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.has_watchdog_config()) { + _internal_mutable_watchdog_config()->::tensorboard::WatchdogConfig::MergeFrom(from._internal_watchdog_config()); + } + if (from.has_exit_code()) { + _internal_mutable_exit_code()->::tensorboard::RequestedExitCode::MergeFrom(from._internal_exit_code()); + } + if (from.shutdown_mode() != 0) { + _internal_set_shutdown_mode(from._internal_shutdown_mode()); + } +} + +void WorkerHeartbeatRequest::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.WorkerHeartbeatRequest) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void WorkerHeartbeatRequest::CopyFrom(const WorkerHeartbeatRequest& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.WorkerHeartbeatRequest) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool WorkerHeartbeatRequest::IsInitialized() const { + return true; +} + +void WorkerHeartbeatRequest::InternalSwap(WorkerHeartbeatRequest* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(WorkerHeartbeatRequest, shutdown_mode_) + + sizeof(WorkerHeartbeatRequest::shutdown_mode_) + - PROTOBUF_FIELD_OFFSET(WorkerHeartbeatRequest, watchdog_config_)>( + reinterpret_cast(&watchdog_config_), + reinterpret_cast(&other->watchdog_config_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata WorkerHeartbeatRequest::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void WorkerHeartbeatResponse::InitAsDefaultInstance() { +} +class WorkerHeartbeatResponse::_Internal { + public: +}; + +WorkerHeartbeatResponse::WorkerHeartbeatResponse(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + worker_log_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.WorkerHeartbeatResponse) +} +WorkerHeartbeatResponse::WorkerHeartbeatResponse(const WorkerHeartbeatResponse& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + worker_log_(from.worker_log_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + hostname_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_hostname().empty()) { + hostname_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_hostname(), + GetArena()); + } + health_status_ = from.health_status_; + // @@protoc_insertion_point(copy_constructor:tensorboard.WorkerHeartbeatResponse) +} + +void WorkerHeartbeatResponse::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_WorkerHeartbeatResponse_event_2eproto.base); + hostname_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + health_status_ = 0; +} + +WorkerHeartbeatResponse::~WorkerHeartbeatResponse() { + // @@protoc_insertion_point(destructor:tensorboard.WorkerHeartbeatResponse) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void WorkerHeartbeatResponse::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + hostname_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void WorkerHeartbeatResponse::ArenaDtor(void* object) { + WorkerHeartbeatResponse* _this = reinterpret_cast< WorkerHeartbeatResponse* >(object); + (void)_this; +} +void WorkerHeartbeatResponse::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void WorkerHeartbeatResponse::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const WorkerHeartbeatResponse& WorkerHeartbeatResponse::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_WorkerHeartbeatResponse_event_2eproto.base); + return *internal_default_instance(); +} + + +void WorkerHeartbeatResponse::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.WorkerHeartbeatResponse) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + worker_log_.Clear(); + hostname_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + health_status_ = 0; + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* WorkerHeartbeatResponse::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // .tensorboard.WorkerHealth health_status = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + _internal_set_health_status(static_cast<::tensorboard::WorkerHealth>(val)); + } else goto handle_unusual; + continue; + // repeated .tensorboard.Event worker_log = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_worker_log(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<18>(ptr)); + } else goto handle_unusual; + continue; + // string hostname = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_hostname(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.WorkerHeartbeatResponse.hostname")); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* WorkerHeartbeatResponse::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.WorkerHeartbeatResponse) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // .tensorboard.WorkerHealth health_status = 1; + if (this->health_status() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_health_status(), target); + } + + // repeated .tensorboard.Event worker_log = 2; + for (unsigned int i = 0, + n = static_cast(this->_internal_worker_log_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(2, this->_internal_worker_log(i), target, stream); + } + + // string hostname = 3; + if (this->hostname().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_hostname().data(), static_cast(this->_internal_hostname().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.WorkerHeartbeatResponse.hostname"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_hostname(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.WorkerHeartbeatResponse) + return target; +} + +size_t WorkerHeartbeatResponse::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.WorkerHeartbeatResponse) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .tensorboard.Event worker_log = 2; + total_size += 1UL * this->_internal_worker_log_size(); + for (const auto& msg : this->worker_log_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // string hostname = 3; + if (this->hostname().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_hostname()); + } + + // .tensorboard.WorkerHealth health_status = 1; + if (this->health_status() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_health_status()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void WorkerHeartbeatResponse::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.WorkerHeartbeatResponse) + GOOGLE_DCHECK_NE(&from, this); + const WorkerHeartbeatResponse* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.WorkerHeartbeatResponse) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.WorkerHeartbeatResponse) + MergeFrom(*source); + } +} + +void WorkerHeartbeatResponse::MergeFrom(const WorkerHeartbeatResponse& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.WorkerHeartbeatResponse) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + worker_log_.MergeFrom(from.worker_log_); + if (from.hostname().size() > 0) { + _internal_set_hostname(from._internal_hostname()); + } + if (from.health_status() != 0) { + _internal_set_health_status(from._internal_health_status()); + } +} + +void WorkerHeartbeatResponse::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.WorkerHeartbeatResponse) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void WorkerHeartbeatResponse::CopyFrom(const WorkerHeartbeatResponse& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.WorkerHeartbeatResponse) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool WorkerHeartbeatResponse::IsInitialized() const { + return true; +} + +void WorkerHeartbeatResponse::InternalSwap(WorkerHeartbeatResponse* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + worker_log_.InternalSwap(&other->worker_log_); + hostname_.Swap(&other->hostname_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(health_status_, other->health_status_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata WorkerHeartbeatResponse::GetMetadata() const { + return GetMetadataStatic(); +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_NOINLINE ::tensorboard::Event* Arena::CreateMaybeMessage< ::tensorboard::Event >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::Event >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::SourceMetadata* Arena::CreateMaybeMessage< ::tensorboard::SourceMetadata >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::SourceMetadata >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::LogMessage* Arena::CreateMaybeMessage< ::tensorboard::LogMessage >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::LogMessage >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::SessionLog* Arena::CreateMaybeMessage< ::tensorboard::SessionLog >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::SessionLog >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::TaggedRunMetadata* Arena::CreateMaybeMessage< ::tensorboard::TaggedRunMetadata >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::TaggedRunMetadata >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::WatchdogConfig* Arena::CreateMaybeMessage< ::tensorboard::WatchdogConfig >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::WatchdogConfig >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::RequestedExitCode* Arena::CreateMaybeMessage< ::tensorboard::RequestedExitCode >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::RequestedExitCode >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::WorkerHeartbeatRequest* Arena::CreateMaybeMessage< ::tensorboard::WorkerHeartbeatRequest >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::WorkerHeartbeatRequest >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::WorkerHeartbeatResponse* Arena::CreateMaybeMessage< ::tensorboard::WorkerHeartbeatResponse >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::WorkerHeartbeatResponse >(arena); +} +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) +#include diff --git a/plugins/mindstudio-insight-plugins/proto/event.pb.h b/plugins/mindstudio-insight-plugins/proto/event.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..2cf29fbaca77b30c3ad5506cacdbad4d1d742afb --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/event.pb.h @@ -0,0 +1,3470 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: event.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_event_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_event_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +#include "summary.pb.h" +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_event_2eproto +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct TableStruct_event_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[9] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_event_2eproto; +namespace tensorboard { +class Event; +class EventDefaultTypeInternal; +extern EventDefaultTypeInternal _Event_default_instance_; +class LogMessage; +class LogMessageDefaultTypeInternal; +extern LogMessageDefaultTypeInternal _LogMessage_default_instance_; +class RequestedExitCode; +class RequestedExitCodeDefaultTypeInternal; +extern RequestedExitCodeDefaultTypeInternal _RequestedExitCode_default_instance_; +class SessionLog; +class SessionLogDefaultTypeInternal; +extern SessionLogDefaultTypeInternal _SessionLog_default_instance_; +class SourceMetadata; +class SourceMetadataDefaultTypeInternal; +extern SourceMetadataDefaultTypeInternal _SourceMetadata_default_instance_; +class TaggedRunMetadata; +class TaggedRunMetadataDefaultTypeInternal; +extern TaggedRunMetadataDefaultTypeInternal _TaggedRunMetadata_default_instance_; +class WatchdogConfig; +class WatchdogConfigDefaultTypeInternal; +extern WatchdogConfigDefaultTypeInternal _WatchdogConfig_default_instance_; +class WorkerHeartbeatRequest; +class WorkerHeartbeatRequestDefaultTypeInternal; +extern WorkerHeartbeatRequestDefaultTypeInternal _WorkerHeartbeatRequest_default_instance_; +class WorkerHeartbeatResponse; +class WorkerHeartbeatResponseDefaultTypeInternal; +extern WorkerHeartbeatResponseDefaultTypeInternal _WorkerHeartbeatResponse_default_instance_; +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> ::tensorboard::Event* Arena::CreateMaybeMessage<::tensorboard::Event>(Arena*); +template<> ::tensorboard::LogMessage* Arena::CreateMaybeMessage<::tensorboard::LogMessage>(Arena*); +template<> ::tensorboard::RequestedExitCode* Arena::CreateMaybeMessage<::tensorboard::RequestedExitCode>(Arena*); +template<> ::tensorboard::SessionLog* Arena::CreateMaybeMessage<::tensorboard::SessionLog>(Arena*); +template<> ::tensorboard::SourceMetadata* Arena::CreateMaybeMessage<::tensorboard::SourceMetadata>(Arena*); +template<> ::tensorboard::TaggedRunMetadata* Arena::CreateMaybeMessage<::tensorboard::TaggedRunMetadata>(Arena*); +template<> ::tensorboard::WatchdogConfig* Arena::CreateMaybeMessage<::tensorboard::WatchdogConfig>(Arena*); +template<> ::tensorboard::WorkerHeartbeatRequest* Arena::CreateMaybeMessage<::tensorboard::WorkerHeartbeatRequest>(Arena*); +template<> ::tensorboard::WorkerHeartbeatResponse* Arena::CreateMaybeMessage<::tensorboard::WorkerHeartbeatResponse>(Arena*); +PROTOBUF_NAMESPACE_CLOSE +namespace tensorboard { + +enum LogMessage_Level : int { + LogMessage_Level_UNKNOWN = 0, + LogMessage_Level_DEBUGGING = 10, + LogMessage_Level_INFO = 20, + LogMessage_Level_WARN = 30, + LogMessage_Level_ERROR = 40, + LogMessage_Level_FATAL = 50, + LogMessage_Level_LogMessage_Level_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), + LogMessage_Level_LogMessage_Level_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() +}; +bool LogMessage_Level_IsValid(int value); +constexpr LogMessage_Level LogMessage_Level_Level_MIN = LogMessage_Level_UNKNOWN; +constexpr LogMessage_Level LogMessage_Level_Level_MAX = LogMessage_Level_FATAL; +constexpr int LogMessage_Level_Level_ARRAYSIZE = LogMessage_Level_Level_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* LogMessage_Level_descriptor(); +template +inline const std::string& LogMessage_Level_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function LogMessage_Level_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + LogMessage_Level_descriptor(), enum_t_value); +} +inline bool LogMessage_Level_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, LogMessage_Level* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + LogMessage_Level_descriptor(), name, value); +} +enum SessionLog_SessionStatus : int { + SessionLog_SessionStatus_STATUS_UNSPECIFIED = 0, + SessionLog_SessionStatus_START = 1, + SessionLog_SessionStatus_STOP = 2, + SessionLog_SessionStatus_CHECKPOINT = 3, + SessionLog_SessionStatus_SessionLog_SessionStatus_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), + SessionLog_SessionStatus_SessionLog_SessionStatus_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() +}; +bool SessionLog_SessionStatus_IsValid(int value); +constexpr SessionLog_SessionStatus SessionLog_SessionStatus_SessionStatus_MIN = SessionLog_SessionStatus_STATUS_UNSPECIFIED; +constexpr SessionLog_SessionStatus SessionLog_SessionStatus_SessionStatus_MAX = SessionLog_SessionStatus_CHECKPOINT; +constexpr int SessionLog_SessionStatus_SessionStatus_ARRAYSIZE = SessionLog_SessionStatus_SessionStatus_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* SessionLog_SessionStatus_descriptor(); +template +inline const std::string& SessionLog_SessionStatus_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function SessionLog_SessionStatus_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + SessionLog_SessionStatus_descriptor(), enum_t_value); +} +inline bool SessionLog_SessionStatus_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, SessionLog_SessionStatus* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + SessionLog_SessionStatus_descriptor(), name, value); +} +enum WorkerHealth : int { + OK = 0, + RECEIVED_SHUTDOWN_SIGNAL = 1, + INTERNAL_ERROR = 2, + SHUTTING_DOWN = 3, + WorkerHealth_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), + WorkerHealth_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() +}; +bool WorkerHealth_IsValid(int value); +constexpr WorkerHealth WorkerHealth_MIN = OK; +constexpr WorkerHealth WorkerHealth_MAX = SHUTTING_DOWN; +constexpr int WorkerHealth_ARRAYSIZE = WorkerHealth_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* WorkerHealth_descriptor(); +template +inline const std::string& WorkerHealth_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function WorkerHealth_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + WorkerHealth_descriptor(), enum_t_value); +} +inline bool WorkerHealth_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, WorkerHealth* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + WorkerHealth_descriptor(), name, value); +} +enum WorkerShutdownMode : int { + DEFAULT = 0, + NOT_CONFIGURED = 1, + WAIT_FOR_COORDINATOR = 2, + SHUTDOWN_AFTER_TIMEOUT = 3, + WorkerShutdownMode_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), + WorkerShutdownMode_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() +}; +bool WorkerShutdownMode_IsValid(int value); +constexpr WorkerShutdownMode WorkerShutdownMode_MIN = DEFAULT; +constexpr WorkerShutdownMode WorkerShutdownMode_MAX = SHUTDOWN_AFTER_TIMEOUT; +constexpr int WorkerShutdownMode_ARRAYSIZE = WorkerShutdownMode_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* WorkerShutdownMode_descriptor(); +template +inline const std::string& WorkerShutdownMode_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function WorkerShutdownMode_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + WorkerShutdownMode_descriptor(), enum_t_value); +} +inline bool WorkerShutdownMode_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, WorkerShutdownMode* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + WorkerShutdownMode_descriptor(), name, value); +} +// =================================================================== + +class Event PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.Event) */ { + public: + inline Event() : Event(nullptr) {} + virtual ~Event(); + + Event(const Event& from); + Event(Event&& from) noexcept + : Event() { + *this = ::std::move(from); + } + + inline Event& operator=(const Event& from) { + CopyFrom(from); + return *this; + } + inline Event& operator=(Event&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Event& default_instance(); + + enum WhatCase { + kFileVersion = 3, + kGraphDef = 4, + kSummary = 5, + kLogMessage = 6, + kSessionLog = 7, + kTaggedRunMetadata = 8, + kMetaGraphDef = 9, + WHAT_NOT_SET = 0, + }; + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Event* internal_default_instance() { + return reinterpret_cast( + &_Event_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(Event& a, Event& b) { + a.Swap(&b); + } + inline void Swap(Event* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Event* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Event* New() const final { + return CreateMaybeMessage(nullptr); + } + + Event* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Event& from); + void MergeFrom(const Event& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Event* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.Event"; + } + protected: + explicit Event(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_event_2eproto); + return ::descriptor_table_event_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kSourceMetadataFieldNumber = 10, + kWallTimeFieldNumber = 1, + kStepFieldNumber = 2, + kFileVersionFieldNumber = 3, + kGraphDefFieldNumber = 4, + kSummaryFieldNumber = 5, + kLogMessageFieldNumber = 6, + kSessionLogFieldNumber = 7, + kTaggedRunMetadataFieldNumber = 8, + kMetaGraphDefFieldNumber = 9, + }; + // .tensorboard.SourceMetadata source_metadata = 10; + bool has_source_metadata() const; + private: + bool _internal_has_source_metadata() const; + public: + void clear_source_metadata(); + const ::tensorboard::SourceMetadata& source_metadata() const; + ::tensorboard::SourceMetadata* release_source_metadata(); + ::tensorboard::SourceMetadata* mutable_source_metadata(); + void set_allocated_source_metadata(::tensorboard::SourceMetadata* source_metadata); + private: + const ::tensorboard::SourceMetadata& _internal_source_metadata() const; + ::tensorboard::SourceMetadata* _internal_mutable_source_metadata(); + public: + void unsafe_arena_set_allocated_source_metadata( + ::tensorboard::SourceMetadata* source_metadata); + ::tensorboard::SourceMetadata* unsafe_arena_release_source_metadata(); + + // double wall_time = 1; + void clear_wall_time(); + double wall_time() const; + void set_wall_time(double value); + private: + double _internal_wall_time() const; + void _internal_set_wall_time(double value); + public: + + // int64 step = 2; + void clear_step(); + ::PROTOBUF_NAMESPACE_ID::int64 step() const; + void set_step(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_step() const; + void _internal_set_step(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // string file_version = 3; + private: + bool _internal_has_file_version() const; + public: + void clear_file_version(); + const std::string& file_version() const; + void set_file_version(const std::string& value); + void set_file_version(std::string&& value); + void set_file_version(const char* value); + void set_file_version(const char* value, size_t size); + std::string* mutable_file_version(); + std::string* release_file_version(); + void set_allocated_file_version(std::string* file_version); + private: + const std::string& _internal_file_version() const; + void _internal_set_file_version(const std::string& value); + std::string* _internal_mutable_file_version(); + public: + + // bytes graph_def = 4; + private: + bool _internal_has_graph_def() const; + public: + void clear_graph_def(); + const std::string& graph_def() const; + void set_graph_def(const std::string& value); + void set_graph_def(std::string&& value); + void set_graph_def(const char* value); + void set_graph_def(const void* value, size_t size); + std::string* mutable_graph_def(); + std::string* release_graph_def(); + void set_allocated_graph_def(std::string* graph_def); + private: + const std::string& _internal_graph_def() const; + void _internal_set_graph_def(const std::string& value); + std::string* _internal_mutable_graph_def(); + public: + + // .tensorboard.Summary summary = 5; + bool has_summary() const; + private: + bool _internal_has_summary() const; + public: + void clear_summary(); + const ::tensorboard::Summary& summary() const; + ::tensorboard::Summary* release_summary(); + ::tensorboard::Summary* mutable_summary(); + void set_allocated_summary(::tensorboard::Summary* summary); + private: + const ::tensorboard::Summary& _internal_summary() const; + ::tensorboard::Summary* _internal_mutable_summary(); + public: + void unsafe_arena_set_allocated_summary( + ::tensorboard::Summary* summary); + ::tensorboard::Summary* unsafe_arena_release_summary(); + + // .tensorboard.LogMessage log_message = 6 [deprecated = true]; + PROTOBUF_DEPRECATED bool has_log_message() const; + private: + bool _internal_has_log_message() const; + public: + PROTOBUF_DEPRECATED void clear_log_message(); + PROTOBUF_DEPRECATED const ::tensorboard::LogMessage& log_message() const; + PROTOBUF_DEPRECATED ::tensorboard::LogMessage* release_log_message(); + PROTOBUF_DEPRECATED ::tensorboard::LogMessage* mutable_log_message(); + PROTOBUF_DEPRECATED void set_allocated_log_message(::tensorboard::LogMessage* log_message); + private: + const ::tensorboard::LogMessage& _internal_log_message() const; + ::tensorboard::LogMessage* _internal_mutable_log_message(); + public: + PROTOBUF_DEPRECATED void unsafe_arena_set_allocated_log_message( + ::tensorboard::LogMessage* log_message); + PROTOBUF_DEPRECATED ::tensorboard::LogMessage* unsafe_arena_release_log_message(); + + // .tensorboard.SessionLog session_log = 7; + bool has_session_log() const; + private: + bool _internal_has_session_log() const; + public: + void clear_session_log(); + const ::tensorboard::SessionLog& session_log() const; + ::tensorboard::SessionLog* release_session_log(); + ::tensorboard::SessionLog* mutable_session_log(); + void set_allocated_session_log(::tensorboard::SessionLog* session_log); + private: + const ::tensorboard::SessionLog& _internal_session_log() const; + ::tensorboard::SessionLog* _internal_mutable_session_log(); + public: + void unsafe_arena_set_allocated_session_log( + ::tensorboard::SessionLog* session_log); + ::tensorboard::SessionLog* unsafe_arena_release_session_log(); + + // .tensorboard.TaggedRunMetadata tagged_run_metadata = 8; + bool has_tagged_run_metadata() const; + private: + bool _internal_has_tagged_run_metadata() const; + public: + void clear_tagged_run_metadata(); + const ::tensorboard::TaggedRunMetadata& tagged_run_metadata() const; + ::tensorboard::TaggedRunMetadata* release_tagged_run_metadata(); + ::tensorboard::TaggedRunMetadata* mutable_tagged_run_metadata(); + void set_allocated_tagged_run_metadata(::tensorboard::TaggedRunMetadata* tagged_run_metadata); + private: + const ::tensorboard::TaggedRunMetadata& _internal_tagged_run_metadata() const; + ::tensorboard::TaggedRunMetadata* _internal_mutable_tagged_run_metadata(); + public: + void unsafe_arena_set_allocated_tagged_run_metadata( + ::tensorboard::TaggedRunMetadata* tagged_run_metadata); + ::tensorboard::TaggedRunMetadata* unsafe_arena_release_tagged_run_metadata(); + + // bytes meta_graph_def = 9; + private: + bool _internal_has_meta_graph_def() const; + public: + void clear_meta_graph_def(); + const std::string& meta_graph_def() const; + void set_meta_graph_def(const std::string& value); + void set_meta_graph_def(std::string&& value); + void set_meta_graph_def(const char* value); + void set_meta_graph_def(const void* value, size_t size); + std::string* mutable_meta_graph_def(); + std::string* release_meta_graph_def(); + void set_allocated_meta_graph_def(std::string* meta_graph_def); + private: + const std::string& _internal_meta_graph_def() const; + void _internal_set_meta_graph_def(const std::string& value); + std::string* _internal_mutable_meta_graph_def(); + public: + + void clear_what(); + WhatCase what_case() const; + // @@protoc_insertion_point(class_scope:tensorboard.Event) + private: + class _Internal; + void set_has_file_version(); + void set_has_graph_def(); + void set_has_summary(); + void set_has_log_message(); + void set_has_session_log(); + void set_has_tagged_run_metadata(); + void set_has_meta_graph_def(); + + inline bool has_what() const; + inline void clear_has_what(); + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::tensorboard::SourceMetadata* source_metadata_; + double wall_time_; + ::PROTOBUF_NAMESPACE_ID::int64 step_; + union WhatUnion { + WhatUnion() {} + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr file_version_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr graph_def_; + ::tensorboard::Summary* summary_; + ::tensorboard::LogMessage* log_message_; + ::tensorboard::SessionLog* session_log_; + ::tensorboard::TaggedRunMetadata* tagged_run_metadata_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr meta_graph_def_; + } what_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::uint32 _oneof_case_[1]; + + friend struct ::TableStruct_event_2eproto; +}; +// ------------------------------------------------------------------- + +class SourceMetadata PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.SourceMetadata) */ { + public: + inline SourceMetadata() : SourceMetadata(nullptr) {} + virtual ~SourceMetadata(); + + SourceMetadata(const SourceMetadata& from); + SourceMetadata(SourceMetadata&& from) noexcept + : SourceMetadata() { + *this = ::std::move(from); + } + + inline SourceMetadata& operator=(const SourceMetadata& from) { + CopyFrom(from); + return *this; + } + inline SourceMetadata& operator=(SourceMetadata&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const SourceMetadata& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const SourceMetadata* internal_default_instance() { + return reinterpret_cast( + &_SourceMetadata_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(SourceMetadata& a, SourceMetadata& b) { + a.Swap(&b); + } + inline void Swap(SourceMetadata* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(SourceMetadata* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline SourceMetadata* New() const final { + return CreateMaybeMessage(nullptr); + } + + SourceMetadata* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const SourceMetadata& from); + void MergeFrom(const SourceMetadata& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(SourceMetadata* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.SourceMetadata"; + } + protected: + explicit SourceMetadata(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_event_2eproto); + return ::descriptor_table_event_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kWriterFieldNumber = 1, + }; + // string writer = 1; + void clear_writer(); + const std::string& writer() const; + void set_writer(const std::string& value); + void set_writer(std::string&& value); + void set_writer(const char* value); + void set_writer(const char* value, size_t size); + std::string* mutable_writer(); + std::string* release_writer(); + void set_allocated_writer(std::string* writer); + private: + const std::string& _internal_writer() const; + void _internal_set_writer(const std::string& value); + std::string* _internal_mutable_writer(); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.SourceMetadata) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr writer_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_event_2eproto; +}; +// ------------------------------------------------------------------- + +class LogMessage PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.LogMessage) */ { + public: + inline LogMessage() : LogMessage(nullptr) {} + virtual ~LogMessage(); + + LogMessage(const LogMessage& from); + LogMessage(LogMessage&& from) noexcept + : LogMessage() { + *this = ::std::move(from); + } + + inline LogMessage& operator=(const LogMessage& from) { + CopyFrom(from); + return *this; + } + inline LogMessage& operator=(LogMessage&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const LogMessage& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const LogMessage* internal_default_instance() { + return reinterpret_cast( + &_LogMessage_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(LogMessage& a, LogMessage& b) { + a.Swap(&b); + } + inline void Swap(LogMessage* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(LogMessage* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline LogMessage* New() const final { + return CreateMaybeMessage(nullptr); + } + + LogMessage* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const LogMessage& from); + void MergeFrom(const LogMessage& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(LogMessage* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.LogMessage"; + } + protected: + explicit LogMessage(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_event_2eproto); + return ::descriptor_table_event_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef LogMessage_Level Level; + static constexpr Level UNKNOWN = + LogMessage_Level_UNKNOWN; + static constexpr Level DEBUGGING = + LogMessage_Level_DEBUGGING; + static constexpr Level INFO = + LogMessage_Level_INFO; + static constexpr Level WARN = + LogMessage_Level_WARN; + static constexpr Level ERROR = + LogMessage_Level_ERROR; + static constexpr Level FATAL = + LogMessage_Level_FATAL; + static inline bool Level_IsValid(int value) { + return LogMessage_Level_IsValid(value); + } + static constexpr Level Level_MIN = + LogMessage_Level_Level_MIN; + static constexpr Level Level_MAX = + LogMessage_Level_Level_MAX; + static constexpr int Level_ARRAYSIZE = + LogMessage_Level_Level_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + Level_descriptor() { + return LogMessage_Level_descriptor(); + } + template + static inline const std::string& Level_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Level_Name."); + return LogMessage_Level_Name(enum_t_value); + } + static inline bool Level_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + Level* value) { + return LogMessage_Level_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kMessageFieldNumber = 2, + kLevelFieldNumber = 1, + }; + // string message = 2; + void clear_message(); + const std::string& message() const; + void set_message(const std::string& value); + void set_message(std::string&& value); + void set_message(const char* value); + void set_message(const char* value, size_t size); + std::string* mutable_message(); + std::string* release_message(); + void set_allocated_message(std::string* message); + private: + const std::string& _internal_message() const; + void _internal_set_message(const std::string& value); + std::string* _internal_mutable_message(); + public: + + // .tensorboard.LogMessage.Level level = 1; + void clear_level(); + ::tensorboard::LogMessage_Level level() const; + void set_level(::tensorboard::LogMessage_Level value); + private: + ::tensorboard::LogMessage_Level _internal_level() const; + void _internal_set_level(::tensorboard::LogMessage_Level value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.LogMessage) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr message_; + int level_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_event_2eproto; +}; +// ------------------------------------------------------------------- + +class SessionLog PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.SessionLog) */ { + public: + inline SessionLog() : SessionLog(nullptr) {} + virtual ~SessionLog(); + + SessionLog(const SessionLog& from); + SessionLog(SessionLog&& from) noexcept + : SessionLog() { + *this = ::std::move(from); + } + + inline SessionLog& operator=(const SessionLog& from) { + CopyFrom(from); + return *this; + } + inline SessionLog& operator=(SessionLog&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const SessionLog& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const SessionLog* internal_default_instance() { + return reinterpret_cast( + &_SessionLog_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(SessionLog& a, SessionLog& b) { + a.Swap(&b); + } + inline void Swap(SessionLog* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(SessionLog* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline SessionLog* New() const final { + return CreateMaybeMessage(nullptr); + } + + SessionLog* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const SessionLog& from); + void MergeFrom(const SessionLog& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(SessionLog* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.SessionLog"; + } + protected: + explicit SessionLog(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_event_2eproto); + return ::descriptor_table_event_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef SessionLog_SessionStatus SessionStatus; + static constexpr SessionStatus STATUS_UNSPECIFIED = + SessionLog_SessionStatus_STATUS_UNSPECIFIED; + static constexpr SessionStatus START = + SessionLog_SessionStatus_START; + static constexpr SessionStatus STOP = + SessionLog_SessionStatus_STOP; + static constexpr SessionStatus CHECKPOINT = + SessionLog_SessionStatus_CHECKPOINT; + static inline bool SessionStatus_IsValid(int value) { + return SessionLog_SessionStatus_IsValid(value); + } + static constexpr SessionStatus SessionStatus_MIN = + SessionLog_SessionStatus_SessionStatus_MIN; + static constexpr SessionStatus SessionStatus_MAX = + SessionLog_SessionStatus_SessionStatus_MAX; + static constexpr int SessionStatus_ARRAYSIZE = + SessionLog_SessionStatus_SessionStatus_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + SessionStatus_descriptor() { + return SessionLog_SessionStatus_descriptor(); + } + template + static inline const std::string& SessionStatus_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function SessionStatus_Name."); + return SessionLog_SessionStatus_Name(enum_t_value); + } + static inline bool SessionStatus_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + SessionStatus* value) { + return SessionLog_SessionStatus_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kCheckpointPathFieldNumber = 2, + kMsgFieldNumber = 3, + kStatusFieldNumber = 1, + }; + // string checkpoint_path = 2; + void clear_checkpoint_path(); + const std::string& checkpoint_path() const; + void set_checkpoint_path(const std::string& value); + void set_checkpoint_path(std::string&& value); + void set_checkpoint_path(const char* value); + void set_checkpoint_path(const char* value, size_t size); + std::string* mutable_checkpoint_path(); + std::string* release_checkpoint_path(); + void set_allocated_checkpoint_path(std::string* checkpoint_path); + private: + const std::string& _internal_checkpoint_path() const; + void _internal_set_checkpoint_path(const std::string& value); + std::string* _internal_mutable_checkpoint_path(); + public: + + // string msg = 3; + void clear_msg(); + const std::string& msg() const; + void set_msg(const std::string& value); + void set_msg(std::string&& value); + void set_msg(const char* value); + void set_msg(const char* value, size_t size); + std::string* mutable_msg(); + std::string* release_msg(); + void set_allocated_msg(std::string* msg); + private: + const std::string& _internal_msg() const; + void _internal_set_msg(const std::string& value); + std::string* _internal_mutable_msg(); + public: + + // .tensorboard.SessionLog.SessionStatus status = 1; + void clear_status(); + ::tensorboard::SessionLog_SessionStatus status() const; + void set_status(::tensorboard::SessionLog_SessionStatus value); + private: + ::tensorboard::SessionLog_SessionStatus _internal_status() const; + void _internal_set_status(::tensorboard::SessionLog_SessionStatus value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.SessionLog) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr checkpoint_path_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr msg_; + int status_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_event_2eproto; +}; +// ------------------------------------------------------------------- + +class TaggedRunMetadata PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.TaggedRunMetadata) */ { + public: + inline TaggedRunMetadata() : TaggedRunMetadata(nullptr) {} + virtual ~TaggedRunMetadata(); + + TaggedRunMetadata(const TaggedRunMetadata& from); + TaggedRunMetadata(TaggedRunMetadata&& from) noexcept + : TaggedRunMetadata() { + *this = ::std::move(from); + } + + inline TaggedRunMetadata& operator=(const TaggedRunMetadata& from) { + CopyFrom(from); + return *this; + } + inline TaggedRunMetadata& operator=(TaggedRunMetadata&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TaggedRunMetadata& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TaggedRunMetadata* internal_default_instance() { + return reinterpret_cast( + &_TaggedRunMetadata_default_instance_); + } + static constexpr int kIndexInFileMessages = + 4; + + friend void swap(TaggedRunMetadata& a, TaggedRunMetadata& b) { + a.Swap(&b); + } + inline void Swap(TaggedRunMetadata* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TaggedRunMetadata* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TaggedRunMetadata* New() const final { + return CreateMaybeMessage(nullptr); + } + + TaggedRunMetadata* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TaggedRunMetadata& from); + void MergeFrom(const TaggedRunMetadata& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TaggedRunMetadata* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.TaggedRunMetadata"; + } + protected: + explicit TaggedRunMetadata(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_event_2eproto); + return ::descriptor_table_event_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kTagFieldNumber = 1, + kRunMetadataFieldNumber = 2, + }; + // string tag = 1; + void clear_tag(); + const std::string& tag() const; + void set_tag(const std::string& value); + void set_tag(std::string&& value); + void set_tag(const char* value); + void set_tag(const char* value, size_t size); + std::string* mutable_tag(); + std::string* release_tag(); + void set_allocated_tag(std::string* tag); + private: + const std::string& _internal_tag() const; + void _internal_set_tag(const std::string& value); + std::string* _internal_mutable_tag(); + public: + + // bytes run_metadata = 2; + void clear_run_metadata(); + const std::string& run_metadata() const; + void set_run_metadata(const std::string& value); + void set_run_metadata(std::string&& value); + void set_run_metadata(const char* value); + void set_run_metadata(const void* value, size_t size); + std::string* mutable_run_metadata(); + std::string* release_run_metadata(); + void set_allocated_run_metadata(std::string* run_metadata); + private: + const std::string& _internal_run_metadata() const; + void _internal_set_run_metadata(const std::string& value); + std::string* _internal_mutable_run_metadata(); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.TaggedRunMetadata) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr tag_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr run_metadata_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_event_2eproto; +}; +// ------------------------------------------------------------------- + +class WatchdogConfig PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.WatchdogConfig) */ { + public: + inline WatchdogConfig() : WatchdogConfig(nullptr) {} + virtual ~WatchdogConfig(); + + WatchdogConfig(const WatchdogConfig& from); + WatchdogConfig(WatchdogConfig&& from) noexcept + : WatchdogConfig() { + *this = ::std::move(from); + } + + inline WatchdogConfig& operator=(const WatchdogConfig& from) { + CopyFrom(from); + return *this; + } + inline WatchdogConfig& operator=(WatchdogConfig&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const WatchdogConfig& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const WatchdogConfig* internal_default_instance() { + return reinterpret_cast( + &_WatchdogConfig_default_instance_); + } + static constexpr int kIndexInFileMessages = + 5; + + friend void swap(WatchdogConfig& a, WatchdogConfig& b) { + a.Swap(&b); + } + inline void Swap(WatchdogConfig* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(WatchdogConfig* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline WatchdogConfig* New() const final { + return CreateMaybeMessage(nullptr); + } + + WatchdogConfig* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const WatchdogConfig& from); + void MergeFrom(const WatchdogConfig& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(WatchdogConfig* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.WatchdogConfig"; + } + protected: + explicit WatchdogConfig(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_event_2eproto); + return ::descriptor_table_event_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kTimeoutMsFieldNumber = 1, + }; + // int64 timeout_ms = 1; + void clear_timeout_ms(); + ::PROTOBUF_NAMESPACE_ID::int64 timeout_ms() const; + void set_timeout_ms(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_timeout_ms() const; + void _internal_set_timeout_ms(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.WatchdogConfig) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::int64 timeout_ms_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_event_2eproto; +}; +// ------------------------------------------------------------------- + +class RequestedExitCode PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.RequestedExitCode) */ { + public: + inline RequestedExitCode() : RequestedExitCode(nullptr) {} + virtual ~RequestedExitCode(); + + RequestedExitCode(const RequestedExitCode& from); + RequestedExitCode(RequestedExitCode&& from) noexcept + : RequestedExitCode() { + *this = ::std::move(from); + } + + inline RequestedExitCode& operator=(const RequestedExitCode& from) { + CopyFrom(from); + return *this; + } + inline RequestedExitCode& operator=(RequestedExitCode&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const RequestedExitCode& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const RequestedExitCode* internal_default_instance() { + return reinterpret_cast( + &_RequestedExitCode_default_instance_); + } + static constexpr int kIndexInFileMessages = + 6; + + friend void swap(RequestedExitCode& a, RequestedExitCode& b) { + a.Swap(&b); + } + inline void Swap(RequestedExitCode* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(RequestedExitCode* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline RequestedExitCode* New() const final { + return CreateMaybeMessage(nullptr); + } + + RequestedExitCode* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const RequestedExitCode& from); + void MergeFrom(const RequestedExitCode& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(RequestedExitCode* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.RequestedExitCode"; + } + protected: + explicit RequestedExitCode(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_event_2eproto); + return ::descriptor_table_event_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kExitCodeFieldNumber = 1, + }; + // int32 exit_code = 1; + void clear_exit_code(); + ::PROTOBUF_NAMESPACE_ID::int32 exit_code() const; + void set_exit_code(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_exit_code() const; + void _internal_set_exit_code(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.RequestedExitCode) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::int32 exit_code_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_event_2eproto; +}; +// ------------------------------------------------------------------- + +class WorkerHeartbeatRequest PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.WorkerHeartbeatRequest) */ { + public: + inline WorkerHeartbeatRequest() : WorkerHeartbeatRequest(nullptr) {} + virtual ~WorkerHeartbeatRequest(); + + WorkerHeartbeatRequest(const WorkerHeartbeatRequest& from); + WorkerHeartbeatRequest(WorkerHeartbeatRequest&& from) noexcept + : WorkerHeartbeatRequest() { + *this = ::std::move(from); + } + + inline WorkerHeartbeatRequest& operator=(const WorkerHeartbeatRequest& from) { + CopyFrom(from); + return *this; + } + inline WorkerHeartbeatRequest& operator=(WorkerHeartbeatRequest&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const WorkerHeartbeatRequest& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const WorkerHeartbeatRequest* internal_default_instance() { + return reinterpret_cast( + &_WorkerHeartbeatRequest_default_instance_); + } + static constexpr int kIndexInFileMessages = + 7; + + friend void swap(WorkerHeartbeatRequest& a, WorkerHeartbeatRequest& b) { + a.Swap(&b); + } + inline void Swap(WorkerHeartbeatRequest* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(WorkerHeartbeatRequest* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline WorkerHeartbeatRequest* New() const final { + return CreateMaybeMessage(nullptr); + } + + WorkerHeartbeatRequest* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const WorkerHeartbeatRequest& from); + void MergeFrom(const WorkerHeartbeatRequest& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(WorkerHeartbeatRequest* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.WorkerHeartbeatRequest"; + } + protected: + explicit WorkerHeartbeatRequest(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_event_2eproto); + return ::descriptor_table_event_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kWatchdogConfigFieldNumber = 2, + kExitCodeFieldNumber = 3, + kShutdownModeFieldNumber = 1, + }; + // .tensorboard.WatchdogConfig watchdog_config = 2; + bool has_watchdog_config() const; + private: + bool _internal_has_watchdog_config() const; + public: + void clear_watchdog_config(); + const ::tensorboard::WatchdogConfig& watchdog_config() const; + ::tensorboard::WatchdogConfig* release_watchdog_config(); + ::tensorboard::WatchdogConfig* mutable_watchdog_config(); + void set_allocated_watchdog_config(::tensorboard::WatchdogConfig* watchdog_config); + private: + const ::tensorboard::WatchdogConfig& _internal_watchdog_config() const; + ::tensorboard::WatchdogConfig* _internal_mutable_watchdog_config(); + public: + void unsafe_arena_set_allocated_watchdog_config( + ::tensorboard::WatchdogConfig* watchdog_config); + ::tensorboard::WatchdogConfig* unsafe_arena_release_watchdog_config(); + + // .tensorboard.RequestedExitCode exit_code = 3; + bool has_exit_code() const; + private: + bool _internal_has_exit_code() const; + public: + void clear_exit_code(); + const ::tensorboard::RequestedExitCode& exit_code() const; + ::tensorboard::RequestedExitCode* release_exit_code(); + ::tensorboard::RequestedExitCode* mutable_exit_code(); + void set_allocated_exit_code(::tensorboard::RequestedExitCode* exit_code); + private: + const ::tensorboard::RequestedExitCode& _internal_exit_code() const; + ::tensorboard::RequestedExitCode* _internal_mutable_exit_code(); + public: + void unsafe_arena_set_allocated_exit_code( + ::tensorboard::RequestedExitCode* exit_code); + ::tensorboard::RequestedExitCode* unsafe_arena_release_exit_code(); + + // .tensorboard.WorkerShutdownMode shutdown_mode = 1; + void clear_shutdown_mode(); + ::tensorboard::WorkerShutdownMode shutdown_mode() const; + void set_shutdown_mode(::tensorboard::WorkerShutdownMode value); + private: + ::tensorboard::WorkerShutdownMode _internal_shutdown_mode() const; + void _internal_set_shutdown_mode(::tensorboard::WorkerShutdownMode value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.WorkerHeartbeatRequest) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::tensorboard::WatchdogConfig* watchdog_config_; + ::tensorboard::RequestedExitCode* exit_code_; + int shutdown_mode_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_event_2eproto; +}; +// ------------------------------------------------------------------- + +class WorkerHeartbeatResponse PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.WorkerHeartbeatResponse) */ { + public: + inline WorkerHeartbeatResponse() : WorkerHeartbeatResponse(nullptr) {} + virtual ~WorkerHeartbeatResponse(); + + WorkerHeartbeatResponse(const WorkerHeartbeatResponse& from); + WorkerHeartbeatResponse(WorkerHeartbeatResponse&& from) noexcept + : WorkerHeartbeatResponse() { + *this = ::std::move(from); + } + + inline WorkerHeartbeatResponse& operator=(const WorkerHeartbeatResponse& from) { + CopyFrom(from); + return *this; + } + inline WorkerHeartbeatResponse& operator=(WorkerHeartbeatResponse&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const WorkerHeartbeatResponse& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const WorkerHeartbeatResponse* internal_default_instance() { + return reinterpret_cast( + &_WorkerHeartbeatResponse_default_instance_); + } + static constexpr int kIndexInFileMessages = + 8; + + friend void swap(WorkerHeartbeatResponse& a, WorkerHeartbeatResponse& b) { + a.Swap(&b); + } + inline void Swap(WorkerHeartbeatResponse* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(WorkerHeartbeatResponse* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline WorkerHeartbeatResponse* New() const final { + return CreateMaybeMessage(nullptr); + } + + WorkerHeartbeatResponse* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const WorkerHeartbeatResponse& from); + void MergeFrom(const WorkerHeartbeatResponse& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(WorkerHeartbeatResponse* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.WorkerHeartbeatResponse"; + } + protected: + explicit WorkerHeartbeatResponse(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_event_2eproto); + return ::descriptor_table_event_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kWorkerLogFieldNumber = 2, + kHostnameFieldNumber = 3, + kHealthStatusFieldNumber = 1, + }; + // repeated .tensorboard.Event worker_log = 2; + int worker_log_size() const; + private: + int _internal_worker_log_size() const; + public: + void clear_worker_log(); + ::tensorboard::Event* mutable_worker_log(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::Event >* + mutable_worker_log(); + private: + const ::tensorboard::Event& _internal_worker_log(int index) const; + ::tensorboard::Event* _internal_add_worker_log(); + public: + const ::tensorboard::Event& worker_log(int index) const; + ::tensorboard::Event* add_worker_log(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::Event >& + worker_log() const; + + // string hostname = 3; + void clear_hostname(); + const std::string& hostname() const; + void set_hostname(const std::string& value); + void set_hostname(std::string&& value); + void set_hostname(const char* value); + void set_hostname(const char* value, size_t size); + std::string* mutable_hostname(); + std::string* release_hostname(); + void set_allocated_hostname(std::string* hostname); + private: + const std::string& _internal_hostname() const; + void _internal_set_hostname(const std::string& value); + std::string* _internal_mutable_hostname(); + public: + + // .tensorboard.WorkerHealth health_status = 1; + void clear_health_status(); + ::tensorboard::WorkerHealth health_status() const; + void set_health_status(::tensorboard::WorkerHealth value); + private: + ::tensorboard::WorkerHealth _internal_health_status() const; + void _internal_set_health_status(::tensorboard::WorkerHealth value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.WorkerHeartbeatResponse) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::Event > worker_log_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr hostname_; + int health_status_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_event_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// Event + +// double wall_time = 1; +inline void Event::clear_wall_time() { + wall_time_ = 0; +} +inline double Event::_internal_wall_time() const { + return wall_time_; +} +inline double Event::wall_time() const { + // @@protoc_insertion_point(field_get:tensorboard.Event.wall_time) + return _internal_wall_time(); +} +inline void Event::_internal_set_wall_time(double value) { + + wall_time_ = value; +} +inline void Event::set_wall_time(double value) { + _internal_set_wall_time(value); + // @@protoc_insertion_point(field_set:tensorboard.Event.wall_time) +} + +// int64 step = 2; +inline void Event::clear_step() { + step_ = PROTOBUF_LONGLONG(0); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Event::_internal_step() const { + return step_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Event::step() const { + // @@protoc_insertion_point(field_get:tensorboard.Event.step) + return _internal_step(); +} +inline void Event::_internal_set_step(::PROTOBUF_NAMESPACE_ID::int64 value) { + + step_ = value; +} +inline void Event::set_step(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_step(value); + // @@protoc_insertion_point(field_set:tensorboard.Event.step) +} + +// string file_version = 3; +inline bool Event::_internal_has_file_version() const { + return what_case() == kFileVersion; +} +inline void Event::set_has_file_version() { + _oneof_case_[0] = kFileVersion; +} +inline void Event::clear_file_version() { + if (_internal_has_file_version()) { + what_.file_version_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + clear_has_what(); + } +} +inline const std::string& Event::file_version() const { + // @@protoc_insertion_point(field_get:tensorboard.Event.file_version) + return _internal_file_version(); +} +inline void Event::set_file_version(const std::string& value) { + _internal_set_file_version(value); + // @@protoc_insertion_point(field_set:tensorboard.Event.file_version) +} +inline std::string* Event::mutable_file_version() { + // @@protoc_insertion_point(field_mutable:tensorboard.Event.file_version) + return _internal_mutable_file_version(); +} +inline const std::string& Event::_internal_file_version() const { + if (_internal_has_file_version()) { + return what_.file_version_.Get(); + } + return *&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(); +} +inline void Event::_internal_set_file_version(const std::string& value) { + if (!_internal_has_file_version()) { + clear_what(); + set_has_file_version(); + what_.file_version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.file_version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Event::set_file_version(std::string&& value) { + // @@protoc_insertion_point(field_set:tensorboard.Event.file_version) + if (!_internal_has_file_version()) { + clear_what(); + set_has_file_version(); + what_.file_version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.file_version_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.Event.file_version) +} +inline void Event::set_file_version(const char* value) { + GOOGLE_DCHECK(value != nullptr); + if (!_internal_has_file_version()) { + clear_what(); + set_has_file_version(); + what_.file_version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.file_version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(value), GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.Event.file_version) +} +inline void Event::set_file_version(const char* value, + size_t size) { + if (!_internal_has_file_version()) { + clear_what(); + set_has_file_version(); + what_.file_version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.file_version_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), + GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.Event.file_version) +} +inline std::string* Event::_internal_mutable_file_version() { + if (!_internal_has_file_version()) { + clear_what(); + set_has_file_version(); + what_.file_version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + return what_.file_version_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Event::release_file_version() { + // @@protoc_insertion_point(field_release:tensorboard.Event.file_version) + if (_internal_has_file_version()) { + clear_has_what(); + return what_.file_version_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + } else { + return nullptr; + } +} +inline void Event::set_allocated_file_version(std::string* file_version) { + if (has_what()) { + clear_what(); + } + if (file_version != nullptr) { + set_has_file_version(); + what_.file_version_.UnsafeSetDefault(file_version); + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); + if (arena != nullptr) { + arena->Own(file_version); + } + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Event.file_version) +} + +// bytes graph_def = 4; +inline bool Event::_internal_has_graph_def() const { + return what_case() == kGraphDef; +} +inline void Event::set_has_graph_def() { + _oneof_case_[0] = kGraphDef; +} +inline void Event::clear_graph_def() { + if (_internal_has_graph_def()) { + what_.graph_def_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + clear_has_what(); + } +} +inline const std::string& Event::graph_def() const { + // @@protoc_insertion_point(field_get:tensorboard.Event.graph_def) + return _internal_graph_def(); +} +inline void Event::set_graph_def(const std::string& value) { + _internal_set_graph_def(value); + // @@protoc_insertion_point(field_set:tensorboard.Event.graph_def) +} +inline std::string* Event::mutable_graph_def() { + // @@protoc_insertion_point(field_mutable:tensorboard.Event.graph_def) + return _internal_mutable_graph_def(); +} +inline const std::string& Event::_internal_graph_def() const { + if (_internal_has_graph_def()) { + return what_.graph_def_.Get(); + } + return *&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(); +} +inline void Event::_internal_set_graph_def(const std::string& value) { + if (!_internal_has_graph_def()) { + clear_what(); + set_has_graph_def(); + what_.graph_def_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.graph_def_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Event::set_graph_def(std::string&& value) { + // @@protoc_insertion_point(field_set:tensorboard.Event.graph_def) + if (!_internal_has_graph_def()) { + clear_what(); + set_has_graph_def(); + what_.graph_def_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.graph_def_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.Event.graph_def) +} +inline void Event::set_graph_def(const char* value) { + GOOGLE_DCHECK(value != nullptr); + if (!_internal_has_graph_def()) { + clear_what(); + set_has_graph_def(); + what_.graph_def_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.graph_def_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(value), GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.Event.graph_def) +} +inline void Event::set_graph_def(const void* value, + size_t size) { + if (!_internal_has_graph_def()) { + clear_what(); + set_has_graph_def(); + what_.graph_def_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.graph_def_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), + GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.Event.graph_def) +} +inline std::string* Event::_internal_mutable_graph_def() { + if (!_internal_has_graph_def()) { + clear_what(); + set_has_graph_def(); + what_.graph_def_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + return what_.graph_def_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Event::release_graph_def() { + // @@protoc_insertion_point(field_release:tensorboard.Event.graph_def) + if (_internal_has_graph_def()) { + clear_has_what(); + return what_.graph_def_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + } else { + return nullptr; + } +} +inline void Event::set_allocated_graph_def(std::string* graph_def) { + if (has_what()) { + clear_what(); + } + if (graph_def != nullptr) { + set_has_graph_def(); + what_.graph_def_.UnsafeSetDefault(graph_def); + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); + if (arena != nullptr) { + arena->Own(graph_def); + } + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Event.graph_def) +} + +// .tensorboard.Summary summary = 5; +inline bool Event::_internal_has_summary() const { + return what_case() == kSummary; +} +inline bool Event::has_summary() const { + return _internal_has_summary(); +} +inline void Event::set_has_summary() { + _oneof_case_[0] = kSummary; +} +inline ::tensorboard::Summary* Event::release_summary() { + // @@protoc_insertion_point(field_release:tensorboard.Event.summary) + if (_internal_has_summary()) { + clear_has_what(); + ::tensorboard::Summary* temp = what_.summary_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + what_.summary_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::tensorboard::Summary& Event::_internal_summary() const { + return _internal_has_summary() + ? *what_.summary_ + : *reinterpret_cast< ::tensorboard::Summary*>(&::tensorboard::_Summary_default_instance_); +} +inline const ::tensorboard::Summary& Event::summary() const { + // @@protoc_insertion_point(field_get:tensorboard.Event.summary) + return _internal_summary(); +} +inline ::tensorboard::Summary* Event::unsafe_arena_release_summary() { + // @@protoc_insertion_point(field_unsafe_arena_release:tensorboard.Event.summary) + if (_internal_has_summary()) { + clear_has_what(); + ::tensorboard::Summary* temp = what_.summary_; + what_.summary_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Event::unsafe_arena_set_allocated_summary(::tensorboard::Summary* summary) { + clear_what(); + if (summary) { + set_has_summary(); + what_.summary_ = summary; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.Event.summary) +} +inline ::tensorboard::Summary* Event::_internal_mutable_summary() { + if (!_internal_has_summary()) { + clear_what(); + set_has_summary(); + what_.summary_ = CreateMaybeMessage< ::tensorboard::Summary >(GetArena()); + } + return what_.summary_; +} +inline ::tensorboard::Summary* Event::mutable_summary() { + // @@protoc_insertion_point(field_mutable:tensorboard.Event.summary) + return _internal_mutable_summary(); +} + +// .tensorboard.LogMessage log_message = 6 [deprecated = true]; +inline bool Event::_internal_has_log_message() const { + return what_case() == kLogMessage; +} +inline bool Event::has_log_message() const { + return _internal_has_log_message(); +} +inline void Event::set_has_log_message() { + _oneof_case_[0] = kLogMessage; +} +inline void Event::clear_log_message() { + if (_internal_has_log_message()) { + if (GetArena() == nullptr) { + delete what_.log_message_; + } + clear_has_what(); + } +} +inline ::tensorboard::LogMessage* Event::release_log_message() { + // @@protoc_insertion_point(field_release:tensorboard.Event.log_message) + if (_internal_has_log_message()) { + clear_has_what(); + ::tensorboard::LogMessage* temp = what_.log_message_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + what_.log_message_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::tensorboard::LogMessage& Event::_internal_log_message() const { + return _internal_has_log_message() + ? *what_.log_message_ + : *reinterpret_cast< ::tensorboard::LogMessage*>(&::tensorboard::_LogMessage_default_instance_); +} +inline const ::tensorboard::LogMessage& Event::log_message() const { + // @@protoc_insertion_point(field_get:tensorboard.Event.log_message) + return _internal_log_message(); +} +inline ::tensorboard::LogMessage* Event::unsafe_arena_release_log_message() { + // @@protoc_insertion_point(field_unsafe_arena_release:tensorboard.Event.log_message) + if (_internal_has_log_message()) { + clear_has_what(); + ::tensorboard::LogMessage* temp = what_.log_message_; + what_.log_message_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Event::unsafe_arena_set_allocated_log_message(::tensorboard::LogMessage* log_message) { + clear_what(); + if (log_message) { + set_has_log_message(); + what_.log_message_ = log_message; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.Event.log_message) +} +inline ::tensorboard::LogMessage* Event::_internal_mutable_log_message() { + if (!_internal_has_log_message()) { + clear_what(); + set_has_log_message(); + what_.log_message_ = CreateMaybeMessage< ::tensorboard::LogMessage >(GetArena()); + } + return what_.log_message_; +} +inline ::tensorboard::LogMessage* Event::mutable_log_message() { + // @@protoc_insertion_point(field_mutable:tensorboard.Event.log_message) + return _internal_mutable_log_message(); +} + +// .tensorboard.SessionLog session_log = 7; +inline bool Event::_internal_has_session_log() const { + return what_case() == kSessionLog; +} +inline bool Event::has_session_log() const { + return _internal_has_session_log(); +} +inline void Event::set_has_session_log() { + _oneof_case_[0] = kSessionLog; +} +inline void Event::clear_session_log() { + if (_internal_has_session_log()) { + if (GetArena() == nullptr) { + delete what_.session_log_; + } + clear_has_what(); + } +} +inline ::tensorboard::SessionLog* Event::release_session_log() { + // @@protoc_insertion_point(field_release:tensorboard.Event.session_log) + if (_internal_has_session_log()) { + clear_has_what(); + ::tensorboard::SessionLog* temp = what_.session_log_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + what_.session_log_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::tensorboard::SessionLog& Event::_internal_session_log() const { + return _internal_has_session_log() + ? *what_.session_log_ + : *reinterpret_cast< ::tensorboard::SessionLog*>(&::tensorboard::_SessionLog_default_instance_); +} +inline const ::tensorboard::SessionLog& Event::session_log() const { + // @@protoc_insertion_point(field_get:tensorboard.Event.session_log) + return _internal_session_log(); +} +inline ::tensorboard::SessionLog* Event::unsafe_arena_release_session_log() { + // @@protoc_insertion_point(field_unsafe_arena_release:tensorboard.Event.session_log) + if (_internal_has_session_log()) { + clear_has_what(); + ::tensorboard::SessionLog* temp = what_.session_log_; + what_.session_log_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Event::unsafe_arena_set_allocated_session_log(::tensorboard::SessionLog* session_log) { + clear_what(); + if (session_log) { + set_has_session_log(); + what_.session_log_ = session_log; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.Event.session_log) +} +inline ::tensorboard::SessionLog* Event::_internal_mutable_session_log() { + if (!_internal_has_session_log()) { + clear_what(); + set_has_session_log(); + what_.session_log_ = CreateMaybeMessage< ::tensorboard::SessionLog >(GetArena()); + } + return what_.session_log_; +} +inline ::tensorboard::SessionLog* Event::mutable_session_log() { + // @@protoc_insertion_point(field_mutable:tensorboard.Event.session_log) + return _internal_mutable_session_log(); +} + +// .tensorboard.TaggedRunMetadata tagged_run_metadata = 8; +inline bool Event::_internal_has_tagged_run_metadata() const { + return what_case() == kTaggedRunMetadata; +} +inline bool Event::has_tagged_run_metadata() const { + return _internal_has_tagged_run_metadata(); +} +inline void Event::set_has_tagged_run_metadata() { + _oneof_case_[0] = kTaggedRunMetadata; +} +inline void Event::clear_tagged_run_metadata() { + if (_internal_has_tagged_run_metadata()) { + if (GetArena() == nullptr) { + delete what_.tagged_run_metadata_; + } + clear_has_what(); + } +} +inline ::tensorboard::TaggedRunMetadata* Event::release_tagged_run_metadata() { + // @@protoc_insertion_point(field_release:tensorboard.Event.tagged_run_metadata) + if (_internal_has_tagged_run_metadata()) { + clear_has_what(); + ::tensorboard::TaggedRunMetadata* temp = what_.tagged_run_metadata_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + what_.tagged_run_metadata_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::tensorboard::TaggedRunMetadata& Event::_internal_tagged_run_metadata() const { + return _internal_has_tagged_run_metadata() + ? *what_.tagged_run_metadata_ + : *reinterpret_cast< ::tensorboard::TaggedRunMetadata*>(&::tensorboard::_TaggedRunMetadata_default_instance_); +} +inline const ::tensorboard::TaggedRunMetadata& Event::tagged_run_metadata() const { + // @@protoc_insertion_point(field_get:tensorboard.Event.tagged_run_metadata) + return _internal_tagged_run_metadata(); +} +inline ::tensorboard::TaggedRunMetadata* Event::unsafe_arena_release_tagged_run_metadata() { + // @@protoc_insertion_point(field_unsafe_arena_release:tensorboard.Event.tagged_run_metadata) + if (_internal_has_tagged_run_metadata()) { + clear_has_what(); + ::tensorboard::TaggedRunMetadata* temp = what_.tagged_run_metadata_; + what_.tagged_run_metadata_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Event::unsafe_arena_set_allocated_tagged_run_metadata(::tensorboard::TaggedRunMetadata* tagged_run_metadata) { + clear_what(); + if (tagged_run_metadata) { + set_has_tagged_run_metadata(); + what_.tagged_run_metadata_ = tagged_run_metadata; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.Event.tagged_run_metadata) +} +inline ::tensorboard::TaggedRunMetadata* Event::_internal_mutable_tagged_run_metadata() { + if (!_internal_has_tagged_run_metadata()) { + clear_what(); + set_has_tagged_run_metadata(); + what_.tagged_run_metadata_ = CreateMaybeMessage< ::tensorboard::TaggedRunMetadata >(GetArena()); + } + return what_.tagged_run_metadata_; +} +inline ::tensorboard::TaggedRunMetadata* Event::mutable_tagged_run_metadata() { + // @@protoc_insertion_point(field_mutable:tensorboard.Event.tagged_run_metadata) + return _internal_mutable_tagged_run_metadata(); +} + +// bytes meta_graph_def = 9; +inline bool Event::_internal_has_meta_graph_def() const { + return what_case() == kMetaGraphDef; +} +inline void Event::set_has_meta_graph_def() { + _oneof_case_[0] = kMetaGraphDef; +} +inline void Event::clear_meta_graph_def() { + if (_internal_has_meta_graph_def()) { + what_.meta_graph_def_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + clear_has_what(); + } +} +inline const std::string& Event::meta_graph_def() const { + // @@protoc_insertion_point(field_get:tensorboard.Event.meta_graph_def) + return _internal_meta_graph_def(); +} +inline void Event::set_meta_graph_def(const std::string& value) { + _internal_set_meta_graph_def(value); + // @@protoc_insertion_point(field_set:tensorboard.Event.meta_graph_def) +} +inline std::string* Event::mutable_meta_graph_def() { + // @@protoc_insertion_point(field_mutable:tensorboard.Event.meta_graph_def) + return _internal_mutable_meta_graph_def(); +} +inline const std::string& Event::_internal_meta_graph_def() const { + if (_internal_has_meta_graph_def()) { + return what_.meta_graph_def_.Get(); + } + return *&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(); +} +inline void Event::_internal_set_meta_graph_def(const std::string& value) { + if (!_internal_has_meta_graph_def()) { + clear_what(); + set_has_meta_graph_def(); + what_.meta_graph_def_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.meta_graph_def_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Event::set_meta_graph_def(std::string&& value) { + // @@protoc_insertion_point(field_set:tensorboard.Event.meta_graph_def) + if (!_internal_has_meta_graph_def()) { + clear_what(); + set_has_meta_graph_def(); + what_.meta_graph_def_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.meta_graph_def_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.Event.meta_graph_def) +} +inline void Event::set_meta_graph_def(const char* value) { + GOOGLE_DCHECK(value != nullptr); + if (!_internal_has_meta_graph_def()) { + clear_what(); + set_has_meta_graph_def(); + what_.meta_graph_def_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.meta_graph_def_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(value), GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.Event.meta_graph_def) +} +inline void Event::set_meta_graph_def(const void* value, + size_t size) { + if (!_internal_has_meta_graph_def()) { + clear_what(); + set_has_meta_graph_def(); + what_.meta_graph_def_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.meta_graph_def_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), + GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.Event.meta_graph_def) +} +inline std::string* Event::_internal_mutable_meta_graph_def() { + if (!_internal_has_meta_graph_def()) { + clear_what(); + set_has_meta_graph_def(); + what_.meta_graph_def_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + return what_.meta_graph_def_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Event::release_meta_graph_def() { + // @@protoc_insertion_point(field_release:tensorboard.Event.meta_graph_def) + if (_internal_has_meta_graph_def()) { + clear_has_what(); + return what_.meta_graph_def_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + } else { + return nullptr; + } +} +inline void Event::set_allocated_meta_graph_def(std::string* meta_graph_def) { + if (has_what()) { + clear_what(); + } + if (meta_graph_def != nullptr) { + set_has_meta_graph_def(); + what_.meta_graph_def_.UnsafeSetDefault(meta_graph_def); + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); + if (arena != nullptr) { + arena->Own(meta_graph_def); + } + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Event.meta_graph_def) +} + +// .tensorboard.SourceMetadata source_metadata = 10; +inline bool Event::_internal_has_source_metadata() const { + return this != internal_default_instance() && source_metadata_ != nullptr; +} +inline bool Event::has_source_metadata() const { + return _internal_has_source_metadata(); +} +inline void Event::clear_source_metadata() { + if (GetArena() == nullptr && source_metadata_ != nullptr) { + delete source_metadata_; + } + source_metadata_ = nullptr; +} +inline const ::tensorboard::SourceMetadata& Event::_internal_source_metadata() const { + const ::tensorboard::SourceMetadata* p = source_metadata_; + return p != nullptr ? *p : *reinterpret_cast( + &::tensorboard::_SourceMetadata_default_instance_); +} +inline const ::tensorboard::SourceMetadata& Event::source_metadata() const { + // @@protoc_insertion_point(field_get:tensorboard.Event.source_metadata) + return _internal_source_metadata(); +} +inline void Event::unsafe_arena_set_allocated_source_metadata( + ::tensorboard::SourceMetadata* source_metadata) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(source_metadata_); + } + source_metadata_ = source_metadata; + if (source_metadata) { + + } else { + + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.Event.source_metadata) +} +inline ::tensorboard::SourceMetadata* Event::release_source_metadata() { + + ::tensorboard::SourceMetadata* temp = source_metadata_; + source_metadata_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::tensorboard::SourceMetadata* Event::unsafe_arena_release_source_metadata() { + // @@protoc_insertion_point(field_release:tensorboard.Event.source_metadata) + + ::tensorboard::SourceMetadata* temp = source_metadata_; + source_metadata_ = nullptr; + return temp; +} +inline ::tensorboard::SourceMetadata* Event::_internal_mutable_source_metadata() { + + if (source_metadata_ == nullptr) { + auto* p = CreateMaybeMessage<::tensorboard::SourceMetadata>(GetArena()); + source_metadata_ = p; + } + return source_metadata_; +} +inline ::tensorboard::SourceMetadata* Event::mutable_source_metadata() { + // @@protoc_insertion_point(field_mutable:tensorboard.Event.source_metadata) + return _internal_mutable_source_metadata(); +} +inline void Event::set_allocated_source_metadata(::tensorboard::SourceMetadata* source_metadata) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete source_metadata_; + } + if (source_metadata) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(source_metadata); + if (message_arena != submessage_arena) { + source_metadata = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, source_metadata, submessage_arena); + } + + } else { + + } + source_metadata_ = source_metadata; + // @@protoc_insertion_point(field_set_allocated:tensorboard.Event.source_metadata) +} + +inline bool Event::has_what() const { + return what_case() != WHAT_NOT_SET; +} +inline void Event::clear_has_what() { + _oneof_case_[0] = WHAT_NOT_SET; +} +inline Event::WhatCase Event::what_case() const { + return Event::WhatCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + +// SourceMetadata + +// string writer = 1; +inline void SourceMetadata::clear_writer() { + writer_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& SourceMetadata::writer() const { + // @@protoc_insertion_point(field_get:tensorboard.SourceMetadata.writer) + return _internal_writer(); +} +inline void SourceMetadata::set_writer(const std::string& value) { + _internal_set_writer(value); + // @@protoc_insertion_point(field_set:tensorboard.SourceMetadata.writer) +} +inline std::string* SourceMetadata::mutable_writer() { + // @@protoc_insertion_point(field_mutable:tensorboard.SourceMetadata.writer) + return _internal_mutable_writer(); +} +inline const std::string& SourceMetadata::_internal_writer() const { + return writer_.Get(); +} +inline void SourceMetadata::_internal_set_writer(const std::string& value) { + + writer_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SourceMetadata::set_writer(std::string&& value) { + + writer_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.SourceMetadata.writer) +} +inline void SourceMetadata::set_writer(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + writer_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.SourceMetadata.writer) +} +inline void SourceMetadata::set_writer(const char* value, + size_t size) { + + writer_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.SourceMetadata.writer) +} +inline std::string* SourceMetadata::_internal_mutable_writer() { + + return writer_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SourceMetadata::release_writer() { + // @@protoc_insertion_point(field_release:tensorboard.SourceMetadata.writer) + return writer_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SourceMetadata::set_allocated_writer(std::string* writer) { + if (writer != nullptr) { + + } else { + + } + writer_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), writer, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.SourceMetadata.writer) +} + +// ------------------------------------------------------------------- + +// LogMessage + +// .tensorboard.LogMessage.Level level = 1; +inline void LogMessage::clear_level() { + level_ = 0; +} +inline ::tensorboard::LogMessage_Level LogMessage::_internal_level() const { + return static_cast< ::tensorboard::LogMessage_Level >(level_); +} +inline ::tensorboard::LogMessage_Level LogMessage::level() const { + // @@protoc_insertion_point(field_get:tensorboard.LogMessage.level) + return _internal_level(); +} +inline void LogMessage::_internal_set_level(::tensorboard::LogMessage_Level value) { + + level_ = value; +} +inline void LogMessage::set_level(::tensorboard::LogMessage_Level value) { + _internal_set_level(value); + // @@protoc_insertion_point(field_set:tensorboard.LogMessage.level) +} + +// string message = 2; +inline void LogMessage::clear_message() { + message_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& LogMessage::message() const { + // @@protoc_insertion_point(field_get:tensorboard.LogMessage.message) + return _internal_message(); +} +inline void LogMessage::set_message(const std::string& value) { + _internal_set_message(value); + // @@protoc_insertion_point(field_set:tensorboard.LogMessage.message) +} +inline std::string* LogMessage::mutable_message() { + // @@protoc_insertion_point(field_mutable:tensorboard.LogMessage.message) + return _internal_mutable_message(); +} +inline const std::string& LogMessage::_internal_message() const { + return message_.Get(); +} +inline void LogMessage::_internal_set_message(const std::string& value) { + + message_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void LogMessage::set_message(std::string&& value) { + + message_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.LogMessage.message) +} +inline void LogMessage::set_message(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + message_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.LogMessage.message) +} +inline void LogMessage::set_message(const char* value, + size_t size) { + + message_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.LogMessage.message) +} +inline std::string* LogMessage::_internal_mutable_message() { + + return message_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* LogMessage::release_message() { + // @@protoc_insertion_point(field_release:tensorboard.LogMessage.message) + return message_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void LogMessage::set_allocated_message(std::string* message) { + if (message != nullptr) { + + } else { + + } + message_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), message, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.LogMessage.message) +} + +// ------------------------------------------------------------------- + +// SessionLog + +// .tensorboard.SessionLog.SessionStatus status = 1; +inline void SessionLog::clear_status() { + status_ = 0; +} +inline ::tensorboard::SessionLog_SessionStatus SessionLog::_internal_status() const { + return static_cast< ::tensorboard::SessionLog_SessionStatus >(status_); +} +inline ::tensorboard::SessionLog_SessionStatus SessionLog::status() const { + // @@protoc_insertion_point(field_get:tensorboard.SessionLog.status) + return _internal_status(); +} +inline void SessionLog::_internal_set_status(::tensorboard::SessionLog_SessionStatus value) { + + status_ = value; +} +inline void SessionLog::set_status(::tensorboard::SessionLog_SessionStatus value) { + _internal_set_status(value); + // @@protoc_insertion_point(field_set:tensorboard.SessionLog.status) +} + +// string checkpoint_path = 2; +inline void SessionLog::clear_checkpoint_path() { + checkpoint_path_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& SessionLog::checkpoint_path() const { + // @@protoc_insertion_point(field_get:tensorboard.SessionLog.checkpoint_path) + return _internal_checkpoint_path(); +} +inline void SessionLog::set_checkpoint_path(const std::string& value) { + _internal_set_checkpoint_path(value); + // @@protoc_insertion_point(field_set:tensorboard.SessionLog.checkpoint_path) +} +inline std::string* SessionLog::mutable_checkpoint_path() { + // @@protoc_insertion_point(field_mutable:tensorboard.SessionLog.checkpoint_path) + return _internal_mutable_checkpoint_path(); +} +inline const std::string& SessionLog::_internal_checkpoint_path() const { + return checkpoint_path_.Get(); +} +inline void SessionLog::_internal_set_checkpoint_path(const std::string& value) { + + checkpoint_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SessionLog::set_checkpoint_path(std::string&& value) { + + checkpoint_path_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.SessionLog.checkpoint_path) +} +inline void SessionLog::set_checkpoint_path(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + checkpoint_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.SessionLog.checkpoint_path) +} +inline void SessionLog::set_checkpoint_path(const char* value, + size_t size) { + + checkpoint_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.SessionLog.checkpoint_path) +} +inline std::string* SessionLog::_internal_mutable_checkpoint_path() { + + return checkpoint_path_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SessionLog::release_checkpoint_path() { + // @@protoc_insertion_point(field_release:tensorboard.SessionLog.checkpoint_path) + return checkpoint_path_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SessionLog::set_allocated_checkpoint_path(std::string* checkpoint_path) { + if (checkpoint_path != nullptr) { + + } else { + + } + checkpoint_path_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), checkpoint_path, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.SessionLog.checkpoint_path) +} + +// string msg = 3; +inline void SessionLog::clear_msg() { + msg_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& SessionLog::msg() const { + // @@protoc_insertion_point(field_get:tensorboard.SessionLog.msg) + return _internal_msg(); +} +inline void SessionLog::set_msg(const std::string& value) { + _internal_set_msg(value); + // @@protoc_insertion_point(field_set:tensorboard.SessionLog.msg) +} +inline std::string* SessionLog::mutable_msg() { + // @@protoc_insertion_point(field_mutable:tensorboard.SessionLog.msg) + return _internal_mutable_msg(); +} +inline const std::string& SessionLog::_internal_msg() const { + return msg_.Get(); +} +inline void SessionLog::_internal_set_msg(const std::string& value) { + + msg_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SessionLog::set_msg(std::string&& value) { + + msg_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.SessionLog.msg) +} +inline void SessionLog::set_msg(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + msg_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.SessionLog.msg) +} +inline void SessionLog::set_msg(const char* value, + size_t size) { + + msg_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.SessionLog.msg) +} +inline std::string* SessionLog::_internal_mutable_msg() { + + return msg_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SessionLog::release_msg() { + // @@protoc_insertion_point(field_release:tensorboard.SessionLog.msg) + return msg_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SessionLog::set_allocated_msg(std::string* msg) { + if (msg != nullptr) { + + } else { + + } + msg_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), msg, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.SessionLog.msg) +} + +// ------------------------------------------------------------------- + +// TaggedRunMetadata + +// string tag = 1; +inline void TaggedRunMetadata::clear_tag() { + tag_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& TaggedRunMetadata::tag() const { + // @@protoc_insertion_point(field_get:tensorboard.TaggedRunMetadata.tag) + return _internal_tag(); +} +inline void TaggedRunMetadata::set_tag(const std::string& value) { + _internal_set_tag(value); + // @@protoc_insertion_point(field_set:tensorboard.TaggedRunMetadata.tag) +} +inline std::string* TaggedRunMetadata::mutable_tag() { + // @@protoc_insertion_point(field_mutable:tensorboard.TaggedRunMetadata.tag) + return _internal_mutable_tag(); +} +inline const std::string& TaggedRunMetadata::_internal_tag() const { + return tag_.Get(); +} +inline void TaggedRunMetadata::_internal_set_tag(const std::string& value) { + + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TaggedRunMetadata::set_tag(std::string&& value) { + + tag_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.TaggedRunMetadata.tag) +} +inline void TaggedRunMetadata::set_tag(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.TaggedRunMetadata.tag) +} +inline void TaggedRunMetadata::set_tag(const char* value, + size_t size) { + + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.TaggedRunMetadata.tag) +} +inline std::string* TaggedRunMetadata::_internal_mutable_tag() { + + return tag_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TaggedRunMetadata::release_tag() { + // @@protoc_insertion_point(field_release:tensorboard.TaggedRunMetadata.tag) + return tag_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TaggedRunMetadata::set_allocated_tag(std::string* tag) { + if (tag != nullptr) { + + } else { + + } + tag_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), tag, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.TaggedRunMetadata.tag) +} + +// bytes run_metadata = 2; +inline void TaggedRunMetadata::clear_run_metadata() { + run_metadata_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& TaggedRunMetadata::run_metadata() const { + // @@protoc_insertion_point(field_get:tensorboard.TaggedRunMetadata.run_metadata) + return _internal_run_metadata(); +} +inline void TaggedRunMetadata::set_run_metadata(const std::string& value) { + _internal_set_run_metadata(value); + // @@protoc_insertion_point(field_set:tensorboard.TaggedRunMetadata.run_metadata) +} +inline std::string* TaggedRunMetadata::mutable_run_metadata() { + // @@protoc_insertion_point(field_mutable:tensorboard.TaggedRunMetadata.run_metadata) + return _internal_mutable_run_metadata(); +} +inline const std::string& TaggedRunMetadata::_internal_run_metadata() const { + return run_metadata_.Get(); +} +inline void TaggedRunMetadata::_internal_set_run_metadata(const std::string& value) { + + run_metadata_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TaggedRunMetadata::set_run_metadata(std::string&& value) { + + run_metadata_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.TaggedRunMetadata.run_metadata) +} +inline void TaggedRunMetadata::set_run_metadata(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + run_metadata_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.TaggedRunMetadata.run_metadata) +} +inline void TaggedRunMetadata::set_run_metadata(const void* value, + size_t size) { + + run_metadata_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.TaggedRunMetadata.run_metadata) +} +inline std::string* TaggedRunMetadata::_internal_mutable_run_metadata() { + + return run_metadata_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TaggedRunMetadata::release_run_metadata() { + // @@protoc_insertion_point(field_release:tensorboard.TaggedRunMetadata.run_metadata) + return run_metadata_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TaggedRunMetadata::set_allocated_run_metadata(std::string* run_metadata) { + if (run_metadata != nullptr) { + + } else { + + } + run_metadata_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), run_metadata, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.TaggedRunMetadata.run_metadata) +} + +// ------------------------------------------------------------------- + +// WatchdogConfig + +// int64 timeout_ms = 1; +inline void WatchdogConfig::clear_timeout_ms() { + timeout_ms_ = PROTOBUF_LONGLONG(0); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 WatchdogConfig::_internal_timeout_ms() const { + return timeout_ms_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 WatchdogConfig::timeout_ms() const { + // @@protoc_insertion_point(field_get:tensorboard.WatchdogConfig.timeout_ms) + return _internal_timeout_ms(); +} +inline void WatchdogConfig::_internal_set_timeout_ms(::PROTOBUF_NAMESPACE_ID::int64 value) { + + timeout_ms_ = value; +} +inline void WatchdogConfig::set_timeout_ms(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_timeout_ms(value); + // @@protoc_insertion_point(field_set:tensorboard.WatchdogConfig.timeout_ms) +} + +// ------------------------------------------------------------------- + +// RequestedExitCode + +// int32 exit_code = 1; +inline void RequestedExitCode::clear_exit_code() { + exit_code_ = 0; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 RequestedExitCode::_internal_exit_code() const { + return exit_code_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 RequestedExitCode::exit_code() const { + // @@protoc_insertion_point(field_get:tensorboard.RequestedExitCode.exit_code) + return _internal_exit_code(); +} +inline void RequestedExitCode::_internal_set_exit_code(::PROTOBUF_NAMESPACE_ID::int32 value) { + + exit_code_ = value; +} +inline void RequestedExitCode::set_exit_code(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_exit_code(value); + // @@protoc_insertion_point(field_set:tensorboard.RequestedExitCode.exit_code) +} + +// ------------------------------------------------------------------- + +// WorkerHeartbeatRequest + +// .tensorboard.WorkerShutdownMode shutdown_mode = 1; +inline void WorkerHeartbeatRequest::clear_shutdown_mode() { + shutdown_mode_ = 0; +} +inline ::tensorboard::WorkerShutdownMode WorkerHeartbeatRequest::_internal_shutdown_mode() const { + return static_cast< ::tensorboard::WorkerShutdownMode >(shutdown_mode_); +} +inline ::tensorboard::WorkerShutdownMode WorkerHeartbeatRequest::shutdown_mode() const { + // @@protoc_insertion_point(field_get:tensorboard.WorkerHeartbeatRequest.shutdown_mode) + return _internal_shutdown_mode(); +} +inline void WorkerHeartbeatRequest::_internal_set_shutdown_mode(::tensorboard::WorkerShutdownMode value) { + + shutdown_mode_ = value; +} +inline void WorkerHeartbeatRequest::set_shutdown_mode(::tensorboard::WorkerShutdownMode value) { + _internal_set_shutdown_mode(value); + // @@protoc_insertion_point(field_set:tensorboard.WorkerHeartbeatRequest.shutdown_mode) +} + +// .tensorboard.WatchdogConfig watchdog_config = 2; +inline bool WorkerHeartbeatRequest::_internal_has_watchdog_config() const { + return this != internal_default_instance() && watchdog_config_ != nullptr; +} +inline bool WorkerHeartbeatRequest::has_watchdog_config() const { + return _internal_has_watchdog_config(); +} +inline void WorkerHeartbeatRequest::clear_watchdog_config() { + if (GetArena() == nullptr && watchdog_config_ != nullptr) { + delete watchdog_config_; + } + watchdog_config_ = nullptr; +} +inline const ::tensorboard::WatchdogConfig& WorkerHeartbeatRequest::_internal_watchdog_config() const { + const ::tensorboard::WatchdogConfig* p = watchdog_config_; + return p != nullptr ? *p : *reinterpret_cast( + &::tensorboard::_WatchdogConfig_default_instance_); +} +inline const ::tensorboard::WatchdogConfig& WorkerHeartbeatRequest::watchdog_config() const { + // @@protoc_insertion_point(field_get:tensorboard.WorkerHeartbeatRequest.watchdog_config) + return _internal_watchdog_config(); +} +inline void WorkerHeartbeatRequest::unsafe_arena_set_allocated_watchdog_config( + ::tensorboard::WatchdogConfig* watchdog_config) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(watchdog_config_); + } + watchdog_config_ = watchdog_config; + if (watchdog_config) { + + } else { + + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.WorkerHeartbeatRequest.watchdog_config) +} +inline ::tensorboard::WatchdogConfig* WorkerHeartbeatRequest::release_watchdog_config() { + + ::tensorboard::WatchdogConfig* temp = watchdog_config_; + watchdog_config_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::tensorboard::WatchdogConfig* WorkerHeartbeatRequest::unsafe_arena_release_watchdog_config() { + // @@protoc_insertion_point(field_release:tensorboard.WorkerHeartbeatRequest.watchdog_config) + + ::tensorboard::WatchdogConfig* temp = watchdog_config_; + watchdog_config_ = nullptr; + return temp; +} +inline ::tensorboard::WatchdogConfig* WorkerHeartbeatRequest::_internal_mutable_watchdog_config() { + + if (watchdog_config_ == nullptr) { + auto* p = CreateMaybeMessage<::tensorboard::WatchdogConfig>(GetArena()); + watchdog_config_ = p; + } + return watchdog_config_; +} +inline ::tensorboard::WatchdogConfig* WorkerHeartbeatRequest::mutable_watchdog_config() { + // @@protoc_insertion_point(field_mutable:tensorboard.WorkerHeartbeatRequest.watchdog_config) + return _internal_mutable_watchdog_config(); +} +inline void WorkerHeartbeatRequest::set_allocated_watchdog_config(::tensorboard::WatchdogConfig* watchdog_config) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete watchdog_config_; + } + if (watchdog_config) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(watchdog_config); + if (message_arena != submessage_arena) { + watchdog_config = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, watchdog_config, submessage_arena); + } + + } else { + + } + watchdog_config_ = watchdog_config; + // @@protoc_insertion_point(field_set_allocated:tensorboard.WorkerHeartbeatRequest.watchdog_config) +} + +// .tensorboard.RequestedExitCode exit_code = 3; +inline bool WorkerHeartbeatRequest::_internal_has_exit_code() const { + return this != internal_default_instance() && exit_code_ != nullptr; +} +inline bool WorkerHeartbeatRequest::has_exit_code() const { + return _internal_has_exit_code(); +} +inline void WorkerHeartbeatRequest::clear_exit_code() { + if (GetArena() == nullptr && exit_code_ != nullptr) { + delete exit_code_; + } + exit_code_ = nullptr; +} +inline const ::tensorboard::RequestedExitCode& WorkerHeartbeatRequest::_internal_exit_code() const { + const ::tensorboard::RequestedExitCode* p = exit_code_; + return p != nullptr ? *p : *reinterpret_cast( + &::tensorboard::_RequestedExitCode_default_instance_); +} +inline const ::tensorboard::RequestedExitCode& WorkerHeartbeatRequest::exit_code() const { + // @@protoc_insertion_point(field_get:tensorboard.WorkerHeartbeatRequest.exit_code) + return _internal_exit_code(); +} +inline void WorkerHeartbeatRequest::unsafe_arena_set_allocated_exit_code( + ::tensorboard::RequestedExitCode* exit_code) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(exit_code_); + } + exit_code_ = exit_code; + if (exit_code) { + + } else { + + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.WorkerHeartbeatRequest.exit_code) +} +inline ::tensorboard::RequestedExitCode* WorkerHeartbeatRequest::release_exit_code() { + + ::tensorboard::RequestedExitCode* temp = exit_code_; + exit_code_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::tensorboard::RequestedExitCode* WorkerHeartbeatRequest::unsafe_arena_release_exit_code() { + // @@protoc_insertion_point(field_release:tensorboard.WorkerHeartbeatRequest.exit_code) + + ::tensorboard::RequestedExitCode* temp = exit_code_; + exit_code_ = nullptr; + return temp; +} +inline ::tensorboard::RequestedExitCode* WorkerHeartbeatRequest::_internal_mutable_exit_code() { + + if (exit_code_ == nullptr) { + auto* p = CreateMaybeMessage<::tensorboard::RequestedExitCode>(GetArena()); + exit_code_ = p; + } + return exit_code_; +} +inline ::tensorboard::RequestedExitCode* WorkerHeartbeatRequest::mutable_exit_code() { + // @@protoc_insertion_point(field_mutable:tensorboard.WorkerHeartbeatRequest.exit_code) + return _internal_mutable_exit_code(); +} +inline void WorkerHeartbeatRequest::set_allocated_exit_code(::tensorboard::RequestedExitCode* exit_code) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete exit_code_; + } + if (exit_code) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(exit_code); + if (message_arena != submessage_arena) { + exit_code = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, exit_code, submessage_arena); + } + + } else { + + } + exit_code_ = exit_code; + // @@protoc_insertion_point(field_set_allocated:tensorboard.WorkerHeartbeatRequest.exit_code) +} + +// ------------------------------------------------------------------- + +// WorkerHeartbeatResponse + +// .tensorboard.WorkerHealth health_status = 1; +inline void WorkerHeartbeatResponse::clear_health_status() { + health_status_ = 0; +} +inline ::tensorboard::WorkerHealth WorkerHeartbeatResponse::_internal_health_status() const { + return static_cast< ::tensorboard::WorkerHealth >(health_status_); +} +inline ::tensorboard::WorkerHealth WorkerHeartbeatResponse::health_status() const { + // @@protoc_insertion_point(field_get:tensorboard.WorkerHeartbeatResponse.health_status) + return _internal_health_status(); +} +inline void WorkerHeartbeatResponse::_internal_set_health_status(::tensorboard::WorkerHealth value) { + + health_status_ = value; +} +inline void WorkerHeartbeatResponse::set_health_status(::tensorboard::WorkerHealth value) { + _internal_set_health_status(value); + // @@protoc_insertion_point(field_set:tensorboard.WorkerHeartbeatResponse.health_status) +} + +// repeated .tensorboard.Event worker_log = 2; +inline int WorkerHeartbeatResponse::_internal_worker_log_size() const { + return worker_log_.size(); +} +inline int WorkerHeartbeatResponse::worker_log_size() const { + return _internal_worker_log_size(); +} +inline void WorkerHeartbeatResponse::clear_worker_log() { + worker_log_.Clear(); +} +inline ::tensorboard::Event* WorkerHeartbeatResponse::mutable_worker_log(int index) { + // @@protoc_insertion_point(field_mutable:tensorboard.WorkerHeartbeatResponse.worker_log) + return worker_log_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::Event >* +WorkerHeartbeatResponse::mutable_worker_log() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.WorkerHeartbeatResponse.worker_log) + return &worker_log_; +} +inline const ::tensorboard::Event& WorkerHeartbeatResponse::_internal_worker_log(int index) const { + return worker_log_.Get(index); +} +inline const ::tensorboard::Event& WorkerHeartbeatResponse::worker_log(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.WorkerHeartbeatResponse.worker_log) + return _internal_worker_log(index); +} +inline ::tensorboard::Event* WorkerHeartbeatResponse::_internal_add_worker_log() { + return worker_log_.Add(); +} +inline ::tensorboard::Event* WorkerHeartbeatResponse::add_worker_log() { + // @@protoc_insertion_point(field_add:tensorboard.WorkerHeartbeatResponse.worker_log) + return _internal_add_worker_log(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::Event >& +WorkerHeartbeatResponse::worker_log() const { + // @@protoc_insertion_point(field_list:tensorboard.WorkerHeartbeatResponse.worker_log) + return worker_log_; +} + +// string hostname = 3; +inline void WorkerHeartbeatResponse::clear_hostname() { + hostname_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& WorkerHeartbeatResponse::hostname() const { + // @@protoc_insertion_point(field_get:tensorboard.WorkerHeartbeatResponse.hostname) + return _internal_hostname(); +} +inline void WorkerHeartbeatResponse::set_hostname(const std::string& value) { + _internal_set_hostname(value); + // @@protoc_insertion_point(field_set:tensorboard.WorkerHeartbeatResponse.hostname) +} +inline std::string* WorkerHeartbeatResponse::mutable_hostname() { + // @@protoc_insertion_point(field_mutable:tensorboard.WorkerHeartbeatResponse.hostname) + return _internal_mutable_hostname(); +} +inline const std::string& WorkerHeartbeatResponse::_internal_hostname() const { + return hostname_.Get(); +} +inline void WorkerHeartbeatResponse::_internal_set_hostname(const std::string& value) { + + hostname_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void WorkerHeartbeatResponse::set_hostname(std::string&& value) { + + hostname_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.WorkerHeartbeatResponse.hostname) +} +inline void WorkerHeartbeatResponse::set_hostname(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + hostname_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.WorkerHeartbeatResponse.hostname) +} +inline void WorkerHeartbeatResponse::set_hostname(const char* value, + size_t size) { + + hostname_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.WorkerHeartbeatResponse.hostname) +} +inline std::string* WorkerHeartbeatResponse::_internal_mutable_hostname() { + + return hostname_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* WorkerHeartbeatResponse::release_hostname() { + // @@protoc_insertion_point(field_release:tensorboard.WorkerHeartbeatResponse.hostname) + return hostname_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void WorkerHeartbeatResponse::set_allocated_hostname(std::string* hostname) { + if (hostname != nullptr) { + + } else { + + } + hostname_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), hostname, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.WorkerHeartbeatResponse.hostname) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +} // namespace tensorboard + +PROTOBUF_NAMESPACE_OPEN + +template <> struct is_proto_enum< ::tensorboard::LogMessage_Level> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::tensorboard::LogMessage_Level>() { + return ::tensorboard::LogMessage_Level_descriptor(); +} +template <> struct is_proto_enum< ::tensorboard::SessionLog_SessionStatus> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::tensorboard::SessionLog_SessionStatus>() { + return ::tensorboard::SessionLog_SessionStatus_descriptor(); +} +template <> struct is_proto_enum< ::tensorboard::WorkerHealth> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::tensorboard::WorkerHealth>() { + return ::tensorboard::WorkerHealth_descriptor(); +} +template <> struct is_proto_enum< ::tensorboard::WorkerShutdownMode> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::tensorboard::WorkerShutdownMode>() { + return ::tensorboard::WorkerShutdownMode_descriptor(); +} + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_event_2eproto diff --git a/plugins/mindstudio-insight-plugins/proto/event.proto b/plugins/mindstudio-insight-plugins/proto/event.proto new file mode 100644 index 0000000000000000000000000000000000000000..f135df01e22caeaafe7fc9e9c32efb98f017fc40 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/event.proto @@ -0,0 +1,141 @@ +syntax = "proto3"; + +package tensorboard; + +import "summary.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "EventProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.util"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/util/event_go_proto"; + +// Protocol buffer representing an event that happened during +// the execution of a Brain model. +message Event { + // Timestamp of the event. + double wall_time = 1; + + // Global step of the event. + int64 step = 2; + + oneof what { + // An event file was started, with the specified version. + // This is use to identify the contents of the record IO files + // easily. Current version is "brain.Event:2". All versions + // start with "brain.Event:". + string file_version = 3; + // An encoded version of a GraphDef. + bytes graph_def = 4; + // A summary was generated. + Summary summary = 5; + // The user output a log message. This was theoretically used by the defunct + // tensorboard_logging module, which has since been removed; this field is + // now deprecated and should not be used. + LogMessage log_message = 6 [deprecated = true]; + // The state of the session which can be used for restarting after crashes. + SessionLog session_log = 7; + // The metadata returned by running a session.run() call. + TaggedRunMetadata tagged_run_metadata = 8; + // An encoded version of a MetaGraphDef. + bytes meta_graph_def = 9; + } + + // Information of the source that writes the events, this is only logged in + // the very first event along with the `file_version` field. + SourceMetadata source_metadata = 10; +} + +// Holds the information of the source that writes the events. +message SourceMetadata { + // Low level name of the summary writer, such as + // `tensorflow.core.util.events_writer`. + string writer = 1; +} + +// Protocol buffer used for logging messages to the events file. +// +// This was theoretically used by the defunct tensorboard_logging module, which +// has been removed; this message is now deprecated and should not be used. +message LogMessage { + option deprecated = true; + enum Level { + option deprecated = true; + UNKNOWN = 0; + // Note: The logging level 10 cannot be named DEBUG. Some software + // projects compile their C/C++ code with -DDEBUG in debug builds. So the + // C++ code generated from this file should not have an identifier named + // DEBUG. + DEBUGGING = 10; + INFO = 20; + WARN = 30; + ERROR = 40; + FATAL = 50; + } + Level level = 1; + string message = 2; +} + +// Protocol buffer used for logging session state. +message SessionLog { + enum SessionStatus { + STATUS_UNSPECIFIED = 0; + START = 1; + STOP = 2; + CHECKPOINT = 3; + } + + SessionStatus status = 1; + // This checkpoint_path contains both the path and filename. + string checkpoint_path = 2; + string msg = 3; +} + +// For logging the metadata output for a single session.run() call. +message TaggedRunMetadata { + // Tag name associated with this metadata. + string tag = 1; + // Byte-encoded version of the `RunMetadata` proto in order to allow lazy + // deserialization. + bytes run_metadata = 2; +} + +// Worker heartbeat messages. Support for these operations is currently +// internal and expected to change. + +// Current health status of a worker. +enum WorkerHealth { + OK = 0; // By default a worker is healthy. + RECEIVED_SHUTDOWN_SIGNAL = 1; + INTERNAL_ERROR = 2; + SHUTTING_DOWN = 3; // Worker has been instructed to shutdown after a timeout. +} + +// Indicates the behavior of the worker when an internal error or shutdown +// signal is received. +enum WorkerShutdownMode { + DEFAULT = 0; + NOT_CONFIGURED = 1; + WAIT_FOR_COORDINATOR = 2; + SHUTDOWN_AFTER_TIMEOUT = 3; +} + +message WatchdogConfig { + int64 timeout_ms = 1; +} + +message RequestedExitCode { + int32 exit_code = 1; +} + +message WorkerHeartbeatRequest { + WorkerShutdownMode shutdown_mode = 1; + WatchdogConfig watchdog_config = 2; + RequestedExitCode exit_code = 3; +} + +message WorkerHeartbeatResponse { + WorkerHealth health_status = 1; + repeated Event worker_log = 2; + string hostname = 3; +} diff --git a/plugins/mindstudio-insight-plugins/proto/histogram.pb.cc b/plugins/mindstudio-insight-plugins/proto/histogram.pb.cc new file mode 100644 index 0000000000000000000000000000000000000000..31fbbd4667f010b19b84ec7f666975f71a331ed6 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/histogram.pb.cc @@ -0,0 +1,458 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: histogram.proto + +#include "histogram.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +namespace tensorboard { +class HistogramProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _HistogramProto_default_instance_; +} // namespace tensorboard +static void InitDefaultsscc_info_HistogramProto_histogram_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_HistogramProto_default_instance_; + new (ptr) ::tensorboard::HistogramProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::HistogramProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_HistogramProto_histogram_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_HistogramProto_histogram_2eproto}, {}}; + +static ::PROTOBUF_NAMESPACE_ID::Metadata file_level_metadata_histogram_2eproto[1]; +static constexpr ::PROTOBUF_NAMESPACE_ID::EnumDescriptor const** file_level_enum_descriptors_histogram_2eproto = nullptr; +static constexpr ::PROTOBUF_NAMESPACE_ID::ServiceDescriptor const** file_level_service_descriptors_histogram_2eproto = nullptr; + +const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_histogram_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::HistogramProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::HistogramProto, min_), + PROTOBUF_FIELD_OFFSET(::tensorboard::HistogramProto, max_), + PROTOBUF_FIELD_OFFSET(::tensorboard::HistogramProto, num_), + PROTOBUF_FIELD_OFFSET(::tensorboard::HistogramProto, sum_), + PROTOBUF_FIELD_OFFSET(::tensorboard::HistogramProto, sum_squares_), + PROTOBUF_FIELD_OFFSET(::tensorboard::HistogramProto, bucket_limit_), + PROTOBUF_FIELD_OFFSET(::tensorboard::HistogramProto, bucket_), +}; +static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, -1, sizeof(::tensorboard::HistogramProto)}, +}; + +static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = { + reinterpret_cast(&::tensorboard::_HistogramProto_default_instance_), +}; + +const char descriptor_table_protodef_histogram_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = + "\n\017histogram.proto\022\013tensorboard\"\207\001\n\016Histo" + "gramProto\022\013\n\003min\030\001 \001(\001\022\013\n\003max\030\002 \001(\001\022\013\n\003n" + "um\030\003 \001(\001\022\013\n\003sum\030\004 \001(\001\022\023\n\013sum_squares\030\005 \001" + "(\001\022\030\n\014bucket_limit\030\006 \003(\001B\002\020\001\022\022\n\006bucket\030\007" + " \003(\001B\002\020\001B\\\n\030org.tensorflow.frameworkP\001Z;" + "github.com/google/tsl/tsl/go/core/protob" + "uf/summary_go_proto\370\001\001b\006proto3" + ; +static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_histogram_2eproto_deps[1] = { +}; +static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_histogram_2eproto_sccs[1] = { + &scc_info_HistogramProto_histogram_2eproto.base, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_histogram_2eproto_once; +const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_histogram_2eproto = { + false, false, descriptor_table_protodef_histogram_2eproto, "histogram.proto", 270, + &descriptor_table_histogram_2eproto_once, descriptor_table_histogram_2eproto_sccs, descriptor_table_histogram_2eproto_deps, 1, 0, + schemas, file_default_instances, TableStruct_histogram_2eproto::offsets, + file_level_metadata_histogram_2eproto, 1, file_level_enum_descriptors_histogram_2eproto, file_level_service_descriptors_histogram_2eproto, +}; + +// Force running AddDescriptors() at dynamic initialization time. +static bool dynamic_init_dummy_histogram_2eproto = (static_cast(::PROTOBUF_NAMESPACE_ID::internal::AddDescriptors(&descriptor_table_histogram_2eproto)), true); +namespace tensorboard { + +// =================================================================== + +void HistogramProto::InitAsDefaultInstance() { +} +class HistogramProto::_Internal { + public: +}; + +HistogramProto::HistogramProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + bucket_limit_(arena), + bucket_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.HistogramProto) +} +HistogramProto::HistogramProto(const HistogramProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + bucket_limit_(from.bucket_limit_), + bucket_(from.bucket_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::memcpy(&min_, &from.min_, + static_cast(reinterpret_cast(&sum_squares_) - + reinterpret_cast(&min_)) + sizeof(sum_squares_)); + // @@protoc_insertion_point(copy_constructor:tensorboard.HistogramProto) +} + +void HistogramProto::SharedCtor() { + ::memset(&min_, 0, static_cast( + reinterpret_cast(&sum_squares_) - + reinterpret_cast(&min_)) + sizeof(sum_squares_)); +} + +HistogramProto::~HistogramProto() { + // @@protoc_insertion_point(destructor:tensorboard.HistogramProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void HistogramProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void HistogramProto::ArenaDtor(void* object) { + HistogramProto* _this = reinterpret_cast< HistogramProto* >(object); + (void)_this; +} +void HistogramProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void HistogramProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const HistogramProto& HistogramProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_HistogramProto_histogram_2eproto.base); + return *internal_default_instance(); +} + + +void HistogramProto::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.HistogramProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + bucket_limit_.Clear(); + bucket_.Clear(); + ::memset(&min_, 0, static_cast( + reinterpret_cast(&sum_squares_) - + reinterpret_cast(&min_)) + sizeof(sum_squares_)); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* HistogramProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // double min = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 9)) { + min_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // double max = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 17)) { + max_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // double num = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 25)) { + num_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // double sum = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 33)) { + sum_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // double sum_squares = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 41)) { + sum_squares_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // repeated double bucket_limit = 6 [packed = true]; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedDoubleParser(_internal_mutable_bucket_limit(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 49) { + _internal_add_bucket_limit(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // repeated double bucket = 7 [packed = true]; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedDoubleParser(_internal_mutable_bucket(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 57) { + _internal_add_bucket(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* HistogramProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.HistogramProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // double min = 1; + if (!(this->min() <= 0 && this->min() >= 0)) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(1, this->_internal_min(), target); + } + + // double max = 2; + if (!(this->max() <= 0 && this->max() >= 0)) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(2, this->_internal_max(), target); + } + + // double num = 3; + if (!(this->num() <= 0 && this->num() >= 0)) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(3, this->_internal_num(), target); + } + + // double sum = 4; + if (!(this->sum() <= 0 && this->sum() >= 0)) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(4, this->_internal_sum(), target); + } + + // double sum_squares = 5; + if (!(this->sum_squares() <= 0 && this->sum_squares() >= 0)) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(5, this->_internal_sum_squares(), target); + } + + // repeated double bucket_limit = 6 [packed = true]; + if (this->_internal_bucket_limit_size() > 0) { + target = stream->WriteFixedPacked(6, _internal_bucket_limit(), target); + } + + // repeated double bucket = 7 [packed = true]; + if (this->_internal_bucket_size() > 0) { + target = stream->WriteFixedPacked(7, _internal_bucket(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.HistogramProto) + return target; +} + +size_t HistogramProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.HistogramProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated double bucket_limit = 6 [packed = true]; + { + unsigned int count = static_cast(this->_internal_bucket_limit_size()); + size_t data_size = 8UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _bucket_limit_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated double bucket = 7 [packed = true]; + { + unsigned int count = static_cast(this->_internal_bucket_size()); + size_t data_size = 8UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _bucket_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // double min = 1; + if (!(this->min() <= 0 && this->min() >= 0)) { + total_size += 1 + 8; + } + + // double max = 2; + if (!(this->max() <= 0 && this->max() >= 0)) { + total_size += 1 + 8; + } + + // double num = 3; + if (!(this->num() <= 0 && this->num() >= 0)) { + total_size += 1 + 8; + } + + // double sum = 4; + if (!(this->sum() <= 0 && this->sum() >= 0)) { + total_size += 1 + 8; + } + + // double sum_squares = 5; + if (!(this->sum_squares() <= 0 && this->sum_squares() >= 0)) { + total_size += 1 + 8; + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void HistogramProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.HistogramProto) + GOOGLE_DCHECK_NE(&from, this); + const HistogramProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.HistogramProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.HistogramProto) + MergeFrom(*source); + } +} + +void HistogramProto::MergeFrom(const HistogramProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.HistogramProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + bucket_limit_.MergeFrom(from.bucket_limit_); + bucket_.MergeFrom(from.bucket_); + if (!(from.min() <= 0 && from.min() >= 0)) { + _internal_set_min(from._internal_min()); + } + if (!(from.max() <= 0 && from.max() >= 0)) { + _internal_set_max(from._internal_max()); + } + if (!(from.num() <= 0 && from.num() >= 0)) { + _internal_set_num(from._internal_num()); + } + if (!(from.sum() <= 0 && from.sum() >= 0)) { + _internal_set_sum(from._internal_sum()); + } + if (!(from.sum_squares() <= 0 && from.sum_squares() >= 0)) { + _internal_set_sum_squares(from._internal_sum_squares()); + } +} + +void HistogramProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.HistogramProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void HistogramProto::CopyFrom(const HistogramProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.HistogramProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool HistogramProto::IsInitialized() const { + return true; +} + +void HistogramProto::InternalSwap(HistogramProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + bucket_limit_.InternalSwap(&other->bucket_limit_); + bucket_.InternalSwap(&other->bucket_); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(HistogramProto, sum_squares_) + + sizeof(HistogramProto::sum_squares_) + - PROTOBUF_FIELD_OFFSET(HistogramProto, min_)>( + reinterpret_cast(&min_), + reinterpret_cast(&other->min_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata HistogramProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_NOINLINE ::tensorboard::HistogramProto* Arena::CreateMaybeMessage< ::tensorboard::HistogramProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::HistogramProto >(arena); +} +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) +#include diff --git a/plugins/mindstudio-insight-plugins/proto/histogram.pb.h b/plugins/mindstudio-insight-plugins/proto/histogram.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..967062bc3b82e9380703f72bb741e4f4089850da --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/histogram.pb.h @@ -0,0 +1,514 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: histogram.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_histogram_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_histogram_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_histogram_2eproto +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct TableStruct_histogram_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[1] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_histogram_2eproto; +namespace tensorboard { +class HistogramProto; +class HistogramProtoDefaultTypeInternal; +extern HistogramProtoDefaultTypeInternal _HistogramProto_default_instance_; +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> ::tensorboard::HistogramProto* Arena::CreateMaybeMessage<::tensorboard::HistogramProto>(Arena*); +PROTOBUF_NAMESPACE_CLOSE +namespace tensorboard { + +// =================================================================== + +class HistogramProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.HistogramProto) */ { + public: + inline HistogramProto() : HistogramProto(nullptr) {} + virtual ~HistogramProto(); + + HistogramProto(const HistogramProto& from); + HistogramProto(HistogramProto&& from) noexcept + : HistogramProto() { + *this = ::std::move(from); + } + + inline HistogramProto& operator=(const HistogramProto& from) { + CopyFrom(from); + return *this; + } + inline HistogramProto& operator=(HistogramProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const HistogramProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const HistogramProto* internal_default_instance() { + return reinterpret_cast( + &_HistogramProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(HistogramProto& a, HistogramProto& b) { + a.Swap(&b); + } + inline void Swap(HistogramProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(HistogramProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline HistogramProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + HistogramProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const HistogramProto& from); + void MergeFrom(const HistogramProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(HistogramProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.HistogramProto"; + } + protected: + explicit HistogramProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_histogram_2eproto); + return ::descriptor_table_histogram_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kBucketLimitFieldNumber = 6, + kBucketFieldNumber = 7, + kMinFieldNumber = 1, + kMaxFieldNumber = 2, + kNumFieldNumber = 3, + kSumFieldNumber = 4, + kSumSquaresFieldNumber = 5, + }; + // repeated double bucket_limit = 6 [packed = true]; + int bucket_limit_size() const; + private: + int _internal_bucket_limit_size() const; + public: + void clear_bucket_limit(); + private: + double _internal_bucket_limit(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + _internal_bucket_limit() const; + void _internal_add_bucket_limit(double value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + _internal_mutable_bucket_limit(); + public: + double bucket_limit(int index) const; + void set_bucket_limit(int index, double value); + void add_bucket_limit(double value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + bucket_limit() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + mutable_bucket_limit(); + + // repeated double bucket = 7 [packed = true]; + int bucket_size() const; + private: + int _internal_bucket_size() const; + public: + void clear_bucket(); + private: + double _internal_bucket(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + _internal_bucket() const; + void _internal_add_bucket(double value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + _internal_mutable_bucket(); + public: + double bucket(int index) const; + void set_bucket(int index, double value); + void add_bucket(double value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + bucket() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + mutable_bucket(); + + // double min = 1; + void clear_min(); + double min() const; + void set_min(double value); + private: + double _internal_min() const; + void _internal_set_min(double value); + public: + + // double max = 2; + void clear_max(); + double max() const; + void set_max(double value); + private: + double _internal_max() const; + void _internal_set_max(double value); + public: + + // double num = 3; + void clear_num(); + double num() const; + void set_num(double value); + private: + double _internal_num() const; + void _internal_set_num(double value); + public: + + // double sum = 4; + void clear_sum(); + double sum() const; + void set_sum(double value); + private: + double _internal_sum() const; + void _internal_set_sum(double value); + public: + + // double sum_squares = 5; + void clear_sum_squares(); + double sum_squares() const; + void set_sum_squares(double value); + private: + double _internal_sum_squares() const; + void _internal_set_sum_squares(double value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.HistogramProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double > bucket_limit_; + mutable std::atomic _bucket_limit_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double > bucket_; + mutable std::atomic _bucket_cached_byte_size_; + double min_; + double max_; + double num_; + double sum_; + double sum_squares_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_histogram_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// HistogramProto + +// double min = 1; +inline void HistogramProto::clear_min() { + min_ = 0; +} +inline double HistogramProto::_internal_min() const { + return min_; +} +inline double HistogramProto::min() const { + // @@protoc_insertion_point(field_get:tensorboard.HistogramProto.min) + return _internal_min(); +} +inline void HistogramProto::_internal_set_min(double value) { + + min_ = value; +} +inline void HistogramProto::set_min(double value) { + _internal_set_min(value); + // @@protoc_insertion_point(field_set:tensorboard.HistogramProto.min) +} + +// double max = 2; +inline void HistogramProto::clear_max() { + max_ = 0; +} +inline double HistogramProto::_internal_max() const { + return max_; +} +inline double HistogramProto::max() const { + // @@protoc_insertion_point(field_get:tensorboard.HistogramProto.max) + return _internal_max(); +} +inline void HistogramProto::_internal_set_max(double value) { + + max_ = value; +} +inline void HistogramProto::set_max(double value) { + _internal_set_max(value); + // @@protoc_insertion_point(field_set:tensorboard.HistogramProto.max) +} + +// double num = 3; +inline void HistogramProto::clear_num() { + num_ = 0; +} +inline double HistogramProto::_internal_num() const { + return num_; +} +inline double HistogramProto::num() const { + // @@protoc_insertion_point(field_get:tensorboard.HistogramProto.num) + return _internal_num(); +} +inline void HistogramProto::_internal_set_num(double value) { + + num_ = value; +} +inline void HistogramProto::set_num(double value) { + _internal_set_num(value); + // @@protoc_insertion_point(field_set:tensorboard.HistogramProto.num) +} + +// double sum = 4; +inline void HistogramProto::clear_sum() { + sum_ = 0; +} +inline double HistogramProto::_internal_sum() const { + return sum_; +} +inline double HistogramProto::sum() const { + // @@protoc_insertion_point(field_get:tensorboard.HistogramProto.sum) + return _internal_sum(); +} +inline void HistogramProto::_internal_set_sum(double value) { + + sum_ = value; +} +inline void HistogramProto::set_sum(double value) { + _internal_set_sum(value); + // @@protoc_insertion_point(field_set:tensorboard.HistogramProto.sum) +} + +// double sum_squares = 5; +inline void HistogramProto::clear_sum_squares() { + sum_squares_ = 0; +} +inline double HistogramProto::_internal_sum_squares() const { + return sum_squares_; +} +inline double HistogramProto::sum_squares() const { + // @@protoc_insertion_point(field_get:tensorboard.HistogramProto.sum_squares) + return _internal_sum_squares(); +} +inline void HistogramProto::_internal_set_sum_squares(double value) { + + sum_squares_ = value; +} +inline void HistogramProto::set_sum_squares(double value) { + _internal_set_sum_squares(value); + // @@protoc_insertion_point(field_set:tensorboard.HistogramProto.sum_squares) +} + +// repeated double bucket_limit = 6 [packed = true]; +inline int HistogramProto::_internal_bucket_limit_size() const { + return bucket_limit_.size(); +} +inline int HistogramProto::bucket_limit_size() const { + return _internal_bucket_limit_size(); +} +inline void HistogramProto::clear_bucket_limit() { + bucket_limit_.Clear(); +} +inline double HistogramProto::_internal_bucket_limit(int index) const { + return bucket_limit_.Get(index); +} +inline double HistogramProto::bucket_limit(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.HistogramProto.bucket_limit) + return _internal_bucket_limit(index); +} +inline void HistogramProto::set_bucket_limit(int index, double value) { + bucket_limit_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.HistogramProto.bucket_limit) +} +inline void HistogramProto::_internal_add_bucket_limit(double value) { + bucket_limit_.Add(value); +} +inline void HistogramProto::add_bucket_limit(double value) { + _internal_add_bucket_limit(value); + // @@protoc_insertion_point(field_add:tensorboard.HistogramProto.bucket_limit) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +HistogramProto::_internal_bucket_limit() const { + return bucket_limit_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +HistogramProto::bucket_limit() const { + // @@protoc_insertion_point(field_list:tensorboard.HistogramProto.bucket_limit) + return _internal_bucket_limit(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +HistogramProto::_internal_mutable_bucket_limit() { + return &bucket_limit_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +HistogramProto::mutable_bucket_limit() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.HistogramProto.bucket_limit) + return _internal_mutable_bucket_limit(); +} + +// repeated double bucket = 7 [packed = true]; +inline int HistogramProto::_internal_bucket_size() const { + return bucket_.size(); +} +inline int HistogramProto::bucket_size() const { + return _internal_bucket_size(); +} +inline void HistogramProto::clear_bucket() { + bucket_.Clear(); +} +inline double HistogramProto::_internal_bucket(int index) const { + return bucket_.Get(index); +} +inline double HistogramProto::bucket(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.HistogramProto.bucket) + return _internal_bucket(index); +} +inline void HistogramProto::set_bucket(int index, double value) { + bucket_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.HistogramProto.bucket) +} +inline void HistogramProto::_internal_add_bucket(double value) { + bucket_.Add(value); +} +inline void HistogramProto::add_bucket(double value) { + _internal_add_bucket(value); + // @@protoc_insertion_point(field_add:tensorboard.HistogramProto.bucket) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +HistogramProto::_internal_bucket() const { + return bucket_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +HistogramProto::bucket() const { + // @@protoc_insertion_point(field_list:tensorboard.HistogramProto.bucket) + return _internal_bucket(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +HistogramProto::_internal_mutable_bucket() { + return &bucket_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +HistogramProto::mutable_bucket() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.HistogramProto.bucket) + return _internal_mutable_bucket(); +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ + +// @@protoc_insertion_point(namespace_scope) + +} // namespace tensorboard + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_histogram_2eproto diff --git a/plugins/mindstudio-insight-plugins/proto/histogram.proto b/plugins/mindstudio-insight-plugins/proto/histogram.proto new file mode 100644 index 0000000000000000000000000000000000000000..568a90b73a1adee038a6eaf09f759f413f97a3f4 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/histogram.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package tensorboard; + +option cc_enable_arenas = true; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/google/tsl/tsl/go/core/protobuf/summary_go_proto"; + +// Serialization format for histogram module in +// tsl/lib/histogram/histogram.h +message HistogramProto { + double min = 1; + double max = 2; + double num = 3; + double sum = 4; + double sum_squares = 5; + + // Parallel arrays encoding the bucket boundaries and the bucket values. + // bucket(i) is the count for the bucket i. The range for + // a bucket is: + // i == 0: -DBL_MAX .. bucket_limit(0) + // i != 0: bucket_limit(i-1) .. bucket_limit(i) + repeated double bucket_limit = 6 [packed = true]; + repeated double bucket = 7 [packed = true]; +} diff --git a/plugins/mindstudio-insight-plugins/proto/mindspore_anf_ir.pb.cc b/plugins/mindstudio-insight-plugins/proto/mindspore_anf_ir.pb.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e9d3d4b5670300b655b0039c2fe26acd4a2f05f --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/mindspore_anf_ir.pb.cc @@ -0,0 +1,6837 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mindspore_anf_ir.proto + +#include "mindspore_anf_ir.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<4> scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_InputProto_mindspore_5fanf_5fir_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_OperatorProto_mindspore_5fanf_5fir_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_OperatorSetProto_mindspore_5fanf_5fir_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_OutputProto_mindspore_5fanf_5fir_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorProto_mindspore_5fanf_5fir_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TensorShapeProto_mindspore_5fanf_5fir_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorShapeProto_Dimension_mindspore_5fanf_5fir_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TypeProto_mindspore_5fanf_5fir_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TypeProto_Tensor_mindspore_5fanf_5fir_2eproto; +namespace mindspore { +namespace irpb { +class ValueProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _ValueProto_default_instance_; +class AttributeProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _AttributeProto_default_instance_; +class NamedValueProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _NamedValueProto_default_instance_; +class TensorShapeProto_DimensionDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TensorShapeProto_Dimension_default_instance_; +class TensorShapeProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TensorShapeProto_default_instance_; +class TypeProto_TensorDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TypeProto_Tensor_default_instance_; +class TypeProto_SequenceDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TypeProto_Sequence_default_instance_; +class TypeProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; + const ::mindspore::irpb::TypeProto_Tensor* tensor_type_; + const ::mindspore::irpb::TypeProto_Sequence* sequence_type_; +} _TypeProto_default_instance_; +class ParameterProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _ParameterProto_default_instance_; +class OutputProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _OutputProto_default_instance_; +class InputProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _InputProto_default_instance_; +class NodeProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _NodeProto_default_instance_; +class ModelProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _ModelProto_default_instance_; +class OperatorProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _OperatorProto_default_instance_; +class OperatorSetProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _OperatorSetProto_default_instance_; +class GraphProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _GraphProto_default_instance_; +class TensorProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TensorProto_default_instance_; +} // namespace irpb +} // namespace mindspore +static void InitDefaultsscc_info_AttributeProto_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_ValueProto_default_instance_; + new (ptr) ::mindspore::irpb::ValueProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::mindspore::irpb::_AttributeProto_default_instance_; + new (ptr) ::mindspore::irpb::AttributeProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::mindspore::irpb::_NamedValueProto_default_instance_; + new (ptr) ::mindspore::irpb::NamedValueProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::mindspore::irpb::_ParameterProto_default_instance_; + new (ptr) ::mindspore::irpb::ParameterProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::mindspore::irpb::_NodeProto_default_instance_; + new (ptr) ::mindspore::irpb::NodeProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::mindspore::irpb::_GraphProto_default_instance_; + new (ptr) ::mindspore::irpb::GraphProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::ValueProto::InitAsDefaultInstance(); + ::mindspore::irpb::AttributeProto::InitAsDefaultInstance(); + ::mindspore::irpb::NamedValueProto::InitAsDefaultInstance(); + ::mindspore::irpb::ParameterProto::InitAsDefaultInstance(); + ::mindspore::irpb::NodeProto::InitAsDefaultInstance(); + ::mindspore::irpb::GraphProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<4> scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 4, 0, InitDefaultsscc_info_AttributeProto_mindspore_5fanf_5fir_2eproto}, { + &scc_info_OutputProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_InputProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_TypeProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_TensorProto_mindspore_5fanf_5fir_2eproto.base,}}; + +static void InitDefaultsscc_info_InputProto_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_InputProto_default_instance_; + new (ptr) ::mindspore::irpb::InputProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::InputProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_InputProto_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_InputProto_mindspore_5fanf_5fir_2eproto}, {}}; + +static void InitDefaultsscc_info_ModelProto_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_ModelProto_default_instance_; + new (ptr) ::mindspore::irpb::ModelProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::ModelProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<2> scc_info_ModelProto_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 2, 0, InitDefaultsscc_info_ModelProto_mindspore_5fanf_5fir_2eproto}, { + &scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_OperatorSetProto_mindspore_5fanf_5fir_2eproto.base,}}; + +static void InitDefaultsscc_info_OperatorProto_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_OperatorProto_default_instance_; + new (ptr) ::mindspore::irpb::OperatorProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::OperatorProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_OperatorProto_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_OperatorProto_mindspore_5fanf_5fir_2eproto}, {}}; + +static void InitDefaultsscc_info_OperatorSetProto_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_OperatorSetProto_default_instance_; + new (ptr) ::mindspore::irpb::OperatorSetProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::OperatorSetProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_OperatorSetProto_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_OperatorSetProto_mindspore_5fanf_5fir_2eproto}, { + &scc_info_OperatorProto_mindspore_5fanf_5fir_2eproto.base,}}; + +static void InitDefaultsscc_info_OutputProto_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_OutputProto_default_instance_; + new (ptr) ::mindspore::irpb::OutputProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::OutputProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_OutputProto_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_OutputProto_mindspore_5fanf_5fir_2eproto}, { + &scc_info_TypeProto_mindspore_5fanf_5fir_2eproto.base,}}; + +static void InitDefaultsscc_info_TensorProto_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_TensorProto_default_instance_; + new (ptr) ::mindspore::irpb::TensorProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::TensorProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorProto_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_TensorProto_mindspore_5fanf_5fir_2eproto}, {}}; + +static void InitDefaultsscc_info_TensorShapeProto_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_TensorShapeProto_default_instance_; + new (ptr) ::mindspore::irpb::TensorShapeProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::TensorShapeProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TensorShapeProto_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_TensorShapeProto_mindspore_5fanf_5fir_2eproto}, { + &scc_info_TensorShapeProto_Dimension_mindspore_5fanf_5fir_2eproto.base,}}; + +static void InitDefaultsscc_info_TensorShapeProto_Dimension_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_TensorShapeProto_Dimension_default_instance_; + new (ptr) ::mindspore::irpb::TensorShapeProto_Dimension(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::TensorShapeProto_Dimension::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorShapeProto_Dimension_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_TensorShapeProto_Dimension_mindspore_5fanf_5fir_2eproto}, {}}; + +static void InitDefaultsscc_info_TypeProto_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_TypeProto_Sequence_default_instance_; + new (ptr) ::mindspore::irpb::TypeProto_Sequence(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::mindspore::irpb::_TypeProto_default_instance_; + new (ptr) ::mindspore::irpb::TypeProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::TypeProto_Sequence::InitAsDefaultInstance(); + ::mindspore::irpb::TypeProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TypeProto_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_TypeProto_mindspore_5fanf_5fir_2eproto}, { + &scc_info_TypeProto_Tensor_mindspore_5fanf_5fir_2eproto.base,}}; + +static void InitDefaultsscc_info_TypeProto_Tensor_mindspore_5fanf_5fir_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_TypeProto_Tensor_default_instance_; + new (ptr) ::mindspore::irpb::TypeProto_Tensor(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::TypeProto_Tensor::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TypeProto_Tensor_mindspore_5fanf_5fir_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_TypeProto_Tensor_mindspore_5fanf_5fir_2eproto}, { + &scc_info_TensorShapeProto_mindspore_5fanf_5fir_2eproto.base,}}; + +static ::PROTOBUF_NAMESPACE_ID::Metadata file_level_metadata_mindspore_5fanf_5fir_2eproto[17]; +static const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* file_level_enum_descriptors_mindspore_5fanf_5fir_2eproto[3]; +static constexpr ::PROTOBUF_NAMESPACE_ID::ServiceDescriptor const** file_level_service_descriptors_mindspore_5fanf_5fir_2eproto = nullptr; + +const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_mindspore_5fanf_5fir_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, dtype_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, bool_val_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, int_val_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, uint_val_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, float_val_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, double_val_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, str_val_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, tensor_val_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, graph_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, bool_vals_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, int_vals_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, uint_vals_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, float_vals_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, double_vals_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, str_vals_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, tensor_vals_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, graphs_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, values_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, dict_val_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ValueProto, type_val_), + 4, + 5, + 6, + 7, + 9, + 8, + 0, + 1, + 2, + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, + 3, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::AttributeProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::AttributeProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::AttributeProto, name_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::AttributeProto, value_), + 0, + 1, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NamedValueProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NamedValueProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NamedValueProto, key_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NamedValueProto, value_), + 0, + 1, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorShapeProto_Dimension, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorShapeProto_Dimension, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorShapeProto_Dimension, size_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorShapeProto_Dimension, name_), + 1, + 0, + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorShapeProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorShapeProto, dim_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto_Tensor, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto_Tensor, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto_Tensor, elem_type_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto_Tensor, shape_), + 1, + 0, + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto_Sequence, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto_Sequence, elem_types_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto, _internal_metadata_), + ~0u, // no _extensions_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto, _oneof_case_[0]), + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto, data_type_), + offsetof(::mindspore::irpb::TypeProtoDefaultTypeInternal, tensor_type_), + offsetof(::mindspore::irpb::TypeProtoDefaultTypeInternal, sequence_type_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TypeProto, value_), + 0, + ~0u, + ~0u, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ParameterProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ParameterProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ParameterProto, name_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ParameterProto, type_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ParameterProto, default_val_), + 0, + 1, + 2, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OutputProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OutputProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OutputProto, name_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OutputProto, type_), + 0, + 1, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::InputProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::InputProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::InputProto, name_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::InputProto, type_), + 0, + 1, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, input_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, name_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, op_type_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, scope_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, attribute_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, output_type_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, output_i_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, full_name_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::NodeProto, instance_name_), + ~0u, + 0, + 1, + 2, + ~0u, + 5, + 6, + 3, + 4, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ModelProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ModelProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ModelProto, ir_version_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ModelProto, domain_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ModelProto, model_version_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ModelProto, graph_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::ModelProto, metadata_operators_), + 3, + 0, + 4, + 1, + 2, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OperatorProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OperatorProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OperatorProto, name_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OperatorProto, config_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OperatorProto, obj_info_), + 0, + 1, + 2, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OperatorSetProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OperatorSetProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OperatorSetProto, operators_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::OperatorSetProto, summary_), + ~0u, + 0, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::GraphProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::GraphProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::GraphProto, node_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::GraphProto, name_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::GraphProto, parameters_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::GraphProto, outputs_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::GraphProto, const_vals_), + ~0u, + 0, + ~0u, + ~0u, + ~0u, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorProto, dims_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorProto, data_type_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorProto, float_data_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorProto, int32_data_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorProto, int64_data_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorProto, double_data_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorProto, uint64_data_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::TensorProto, raw_data_), + ~0u, + 1, + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, + 0, +}; +static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, 25, sizeof(::mindspore::irpb::ValueProto)}, + { 45, 52, sizeof(::mindspore::irpb::AttributeProto)}, + { 54, 61, sizeof(::mindspore::irpb::NamedValueProto)}, + { 63, 70, sizeof(::mindspore::irpb::TensorShapeProto_Dimension)}, + { 72, -1, sizeof(::mindspore::irpb::TensorShapeProto)}, + { 78, 85, sizeof(::mindspore::irpb::TypeProto_Tensor)}, + { 87, -1, sizeof(::mindspore::irpb::TypeProto_Sequence)}, + { 93, 102, sizeof(::mindspore::irpb::TypeProto)}, + { 105, 113, sizeof(::mindspore::irpb::ParameterProto)}, + { 116, 123, sizeof(::mindspore::irpb::OutputProto)}, + { 125, 132, sizeof(::mindspore::irpb::InputProto)}, + { 134, 148, sizeof(::mindspore::irpb::NodeProto)}, + { 157, 167, sizeof(::mindspore::irpb::ModelProto)}, + { 172, 180, sizeof(::mindspore::irpb::OperatorProto)}, + { 183, 190, sizeof(::mindspore::irpb::OperatorSetProto)}, + { 192, 202, sizeof(::mindspore::irpb::GraphProto)}, + { 207, 220, sizeof(::mindspore::irpb::TensorProto)}, +}; + +static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = { + reinterpret_cast(&::mindspore::irpb::_ValueProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_AttributeProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_NamedValueProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_TensorShapeProto_Dimension_default_instance_), + reinterpret_cast(&::mindspore::irpb::_TensorShapeProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_TypeProto_Tensor_default_instance_), + reinterpret_cast(&::mindspore::irpb::_TypeProto_Sequence_default_instance_), + reinterpret_cast(&::mindspore::irpb::_TypeProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_ParameterProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_OutputProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_InputProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_NodeProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_ModelProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_OperatorProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_OperatorSetProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_GraphProto_default_instance_), + reinterpret_cast(&::mindspore::irpb::_TensorProto_default_instance_), +}; + +const char descriptor_table_protodef_mindspore_5fanf_5fir_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = + "\n\026mindspore_anf_ir.proto\022\016mindspore.irpb" + "\"\333\004\n\nValueProto\022\'\n\005dtype\030\001 \001(\0162\030.mindspo" + "re.irpb.DataType\022\020\n\010bool_val\030\002 \001(\010\022\017\n\007in" + "t_val\030\003 \001(\003\022\020\n\010uint_val\030\004 \001(\004\022\021\n\tfloat_v" + "al\030\005 \001(\002\022\022\n\ndouble_val\030\006 \001(\001\022\017\n\007str_val\030" + "\007 \001(\t\022/\n\ntensor_val\030\010 \001(\0132\033.mindspore.ir" + "pb.TensorProto\022)\n\005graph\030\t \001(\0132\032.mindspor" + "e.irpb.GraphProto\022\021\n\tbool_vals\030\n \003(\010\022\020\n\010" + "int_vals\030\013 \003(\003\022\021\n\tuint_vals\030\014 \003(\004\022\022\n\nflo" + "at_vals\030\r \003(\002\022\023\n\013double_vals\030\016 \003(\001\022\020\n\010st" + "r_vals\030\017 \003(\t\0220\n\013tensor_vals\030\020 \003(\0132\033.mind" + "spore.irpb.TensorProto\022*\n\006graphs\030\021 \003(\0132\032" + ".mindspore.irpb.GraphProto\022*\n\006values\030\022 \003" + "(\0132\032.mindspore.irpb.ValueProto\0221\n\010dict_v" + "al\030\023 \003(\0132\037.mindspore.irpb.NamedValueProt" + "o\022+\n\010type_val\030\024 \001(\0132\031.mindspore.irpb.Typ" + "eProto\"I\n\016AttributeProto\022\014\n\004name\030\001 \001(\t\022)" + "\n\005value\030\002 \001(\0132\032.mindspore.irpb.ValueProt" + "o\"I\n\017NamedValueProto\022\013\n\003key\030\001 \001(\t\022)\n\005val" + "ue\030\002 \001(\0132\032.mindspore.irpb.ValueProto\"t\n\020" + "TensorShapeProto\0227\n\003dim\030\001 \003(\0132*.mindspor" + "e.irpb.TensorShapeProto.Dimension\032\'\n\tDim" + "ension\022\014\n\004size\030\001 \001(\003\022\014\n\004name\030\002 \001(\t\"\332\002\n\tT" + "ypeProto\022+\n\tdata_type\030\001 \001(\0162\030.mindspore." + "irpb.DataType\0227\n\013tensor_type\030\002 \001(\0132 .min" + "dspore.irpb.TypeProto.TensorH\000\022;\n\rsequen" + "ce_type\030\003 \001(\0132\".mindspore.irpb.TypeProto" + ".SequenceH\000\032f\n\006Tensor\022+\n\telem_type\030\001 \001(\016" + "2\030.mindspore.irpb.DataType\022/\n\005shape\030\002 \001(" + "\0132 .mindspore.irpb.TensorShapeProto\0329\n\010S" + "equence\022-\n\nelem_types\030\001 \003(\0132\031.mindspore." + "irpb.TypeProtoB\007\n\005value\"x\n\016ParameterProt" + "o\022\014\n\004name\030\001 \001(\t\022\'\n\004type\030\002 \001(\0132\031.mindspor" + "e.irpb.TypeProto\022/\n\013default_val\030\003 \001(\0132\032." + "mindspore.irpb.ValueProto\"D\n\013OutputProto" + "\022\014\n\004name\030\001 \001(\t\022\'\n\004type\030\002 \001(\0132\031.mindspore" + ".irpb.TypeProto\"z\n\nInputProto\022\014\n\004name\030\001 " + "\001(\t\0221\n\004type\030\002 \001(\0162#.mindspore.irpb.Input" + "Proto.EdgeType\"+\n\010EdgeType\022\r\n\tDATA_EDGE\020" + "\000\022\020\n\014CONTROL_EDGE\020\001\"\203\002\n\tNodeProto\022)\n\005inp" + "ut\030\001 \003(\0132\032.mindspore.irpb.InputProto\022\014\n\004" + "name\030\002 \001(\t\022\017\n\007op_type\030\003 \001(\t\022\r\n\005scope\030\004 \001" + "(\t\0221\n\tattribute\030\005 \003(\0132\036.mindspore.irpb.A" + "ttributeProto\022.\n\013output_type\030\006 \001(\0132\031.min" + "dspore.irpb.TypeProto\022\020\n\010output_i\030\007 \001(\004\022" + "\021\n\tfull_name\030\010 \001(\t\022\025\n\rinstance_name\030\n \001(" + "\t\"\260\001\n\nModelProto\022\022\n\nir_version\030\001 \001(\003\022\016\n\006" + "domain\030\002 \001(\t\022\025\n\rmodel_version\030\003 \001(\003\022)\n\005g" + "raph\030\004 \001(\0132\032.mindspore.irpb.GraphProto\022<" + "\n\022metadata_operators\030\005 \001(\0132 .mindspore.i" + "rpb.OperatorSetProto\"\?\n\rOperatorProto\022\014\n" + "\004name\030\001 \001(\t\022\016\n\006config\030\002 \001(\014\022\020\n\010obj_info\030" + "\003 \001(\014\"U\n\020OperatorSetProto\0220\n\toperators\030\001" + " \003(\0132\035.mindspore.irpb.OperatorProto\022\017\n\007s" + "ummary\030\002 \001(\t\"\332\001\n\nGraphProto\022\'\n\004node\030\001 \003(" + "\0132\031.mindspore.irpb.NodeProto\022\014\n\004name\030\002 \001" + "(\t\0222\n\nparameters\030\003 \003(\0132\036.mindspore.irpb." + "ParameterProto\022,\n\007outputs\030\004 \003(\0132\033.mindsp" + "ore.irpb.OutputProto\0223\n\nconst_vals\030\005 \003(\013" + "2\037.mindspore.irpb.NamedValueProto\"\324\001\n\013Te" + "nsorProto\022\014\n\004dims\030\001 \003(\003\022+\n\tdata_type\030\002 \001" + "(\0162\030.mindspore.irpb.DataType\022\026\n\nfloat_da" + "ta\030\003 \003(\002B\002\020\001\022\026\n\nint32_data\030\004 \003(\005B\002\020\001\022\026\n\n" + "int64_data\030\005 \003(\003B\002\020\001\022\027\n\013double_data\030\006 \003(" + "\001B\002\020\001\022\027\n\013uint64_data\030\007 \003(\004B\002\020\001\022\020\n\010raw_da" + "ta\030\010 \001(\014*/\n\007Version\022\024\n\020UNKNOWWN_VERSION\020" + "\000\022\016\n\nIR_VERSION\020\001*\211\006\n\010DataType\022\020\n\014DT_UND" + "EFINED\020\000\022\013\n\007DT_BOOL\020\001\022\013\n\007DT_INT8\020\002\022\014\n\010DT" + "_INT16\020\003\022\014\n\010DT_INT32\020\004\022\014\n\010DT_INT64\020\005\022\014\n\010" + "DT_UINT8\020\006\022\r\n\tDT_UINT16\020\007\022\r\n\tDT_UINT32\020\010" + "\022\r\n\tDT_UINT64\020\t\022\016\n\nDT_FLOAT16\020\n\022\016\n\nDT_FL" + "OAT32\020\013\022\016\n\nDT_FLOAT64\020\014\022\r\n\tDT_STRING\020\r\022\r" + "\n\tDT_TENSOR\020\016\022\014\n\010DT_GRAPH\020\017\022\014\n\010DT_BOOLS\020" + "\020\022\014\n\010DT_INTS8\020\021\022\r\n\tDT_INTS16\020\022\022\r\n\tDT_INT" + "S32\020\023\022\r\n\tDT_INTS64\020\024\022\r\n\tDT_UINTS8\020\025\022\016\n\nD" + "T_UINTS16\020\026\022\016\n\nDT_UINTS32\020\027\022\016\n\nDT_UINTS6" + "4\020\030\022\017\n\013DT_FLOATS16\020\031\022\017\n\013DT_FLOATS32\020\032\022\017\n" + "\013DT_FLOATS64\020\033\022\016\n\nDT_STRINGS\020\034\022\016\n\nDT_TEN" + "SORS\020\035\022\r\n\tDT_GRAPHS\020\036\022\014\n\010DT_TUPLE\020\037\022\013\n\007D" + "T_LIST\020 \022\013\n\007DT_DICT\020!\022\013\n\007DT_NONE\020\"\022\017\n\013DT" + "_SYM_INST\020#\022\017\n\013DT_BASE_INT\020$\022\020\n\014DT_BASE_" + "UINT\020%\022\021\n\rDT_BASE_FLOAT\020&\022\013\n\007DT_TYPE\020\'\022\n" + "\n\006DT_ANY\020(\022\r\n\tDT_REFKEY\020)\022\n\n\006DT_REF\020*\022\020\n" + "\014DT_COMPLEX64\020+\022\021\n\rDT_COMPLEX128\020,\022\023\n\017DT" + "_BASE_COMPLEX\020-\022\017\n\013DT_BFLOAT16\020.\022\020\n\014DT_B" + "FLOATS16\020/\022\013\n\007DT_INT4\0200\022\014\n\010DT_SLICE\0201" + ; +static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_mindspore_5fanf_5fir_2eproto_deps[1] = { +}; +static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_mindspore_5fanf_5fir_2eproto_sccs[11] = { + &scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_InputProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_ModelProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_OperatorProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_OperatorSetProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_OutputProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_TensorProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_TensorShapeProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_TensorShapeProto_Dimension_mindspore_5fanf_5fir_2eproto.base, + &scc_info_TypeProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_TypeProto_Tensor_mindspore_5fanf_5fir_2eproto.base, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_mindspore_5fanf_5fir_2eproto_once; +const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_mindspore_5fanf_5fir_2eproto = { + false, false, descriptor_table_protodef_mindspore_5fanf_5fir_2eproto, "mindspore_anf_ir.proto", 3437, + &descriptor_table_mindspore_5fanf_5fir_2eproto_once, descriptor_table_mindspore_5fanf_5fir_2eproto_sccs, descriptor_table_mindspore_5fanf_5fir_2eproto_deps, 11, 0, + schemas, file_default_instances, TableStruct_mindspore_5fanf_5fir_2eproto::offsets, + file_level_metadata_mindspore_5fanf_5fir_2eproto, 17, file_level_enum_descriptors_mindspore_5fanf_5fir_2eproto, file_level_service_descriptors_mindspore_5fanf_5fir_2eproto, +}; + +// Force running AddDescriptors() at dynamic initialization time. +static bool dynamic_init_dummy_mindspore_5fanf_5fir_2eproto = (static_cast(::PROTOBUF_NAMESPACE_ID::internal::AddDescriptors(&descriptor_table_mindspore_5fanf_5fir_2eproto)), true); +namespace mindspore { +namespace irpb { +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* InputProto_EdgeType_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_mindspore_5fanf_5fir_2eproto); + return file_level_enum_descriptors_mindspore_5fanf_5fir_2eproto[0]; +} +bool InputProto_EdgeType_IsValid(int value) { + switch (value) { + case 0: + case 1: + return true; + default: + return false; + } +} + +#if (__cplusplus < 201703) && (!defined(_MSC_VER) || _MSC_VER >= 1900) +constexpr InputProto_EdgeType InputProto::DATA_EDGE; +constexpr InputProto_EdgeType InputProto::CONTROL_EDGE; +constexpr InputProto_EdgeType InputProto::EdgeType_MIN; +constexpr InputProto_EdgeType InputProto::EdgeType_MAX; +constexpr int InputProto::EdgeType_ARRAYSIZE; +#endif // (__cplusplus < 201703) && (!defined(_MSC_VER) || _MSC_VER >= 1900) +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* Version_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_mindspore_5fanf_5fir_2eproto); + return file_level_enum_descriptors_mindspore_5fanf_5fir_2eproto[1]; +} +bool Version_IsValid(int value) { + switch (value) { + case 0: + case 1: + return true; + default: + return false; + } +} + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* DataType_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_mindspore_5fanf_5fir_2eproto); + return file_level_enum_descriptors_mindspore_5fanf_5fir_2eproto[2]; +} +bool DataType_IsValid(int value) { + switch (value) { + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + case 8: + case 9: + case 10: + case 11: + case 12: + case 13: + case 14: + case 15: + case 16: + case 17: + case 18: + case 19: + case 20: + case 21: + case 22: + case 23: + case 24: + case 25: + case 26: + case 27: + case 28: + case 29: + case 30: + case 31: + case 32: + case 33: + case 34: + case 35: + case 36: + case 37: + case 38: + case 39: + case 40: + case 41: + case 42: + case 43: + case 44: + case 45: + case 46: + case 47: + case 48: + case 49: + return true; + default: + return false; + } +} + + +// =================================================================== + +void ValueProto::InitAsDefaultInstance() { + ::mindspore::irpb::_ValueProto_default_instance_._instance.get_mutable()->tensor_val_ = const_cast< ::mindspore::irpb::TensorProto*>( + ::mindspore::irpb::TensorProto::internal_default_instance()); + ::mindspore::irpb::_ValueProto_default_instance_._instance.get_mutable()->graph_ = const_cast< ::mindspore::irpb::GraphProto*>( + ::mindspore::irpb::GraphProto::internal_default_instance()); + ::mindspore::irpb::_ValueProto_default_instance_._instance.get_mutable()->type_val_ = const_cast< ::mindspore::irpb::TypeProto*>( + ::mindspore::irpb::TypeProto::internal_default_instance()); +} +class ValueProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_dtype(HasBits* has_bits) { + (*has_bits)[0] |= 16u; + } + static void set_has_bool_val(HasBits* has_bits) { + (*has_bits)[0] |= 32u; + } + static void set_has_int_val(HasBits* has_bits) { + (*has_bits)[0] |= 64u; + } + static void set_has_uint_val(HasBits* has_bits) { + (*has_bits)[0] |= 128u; + } + static void set_has_float_val(HasBits* has_bits) { + (*has_bits)[0] |= 512u; + } + static void set_has_double_val(HasBits* has_bits) { + (*has_bits)[0] |= 256u; + } + static void set_has_str_val(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::mindspore::irpb::TensorProto& tensor_val(const ValueProto* msg); + static void set_has_tensor_val(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static const ::mindspore::irpb::GraphProto& graph(const ValueProto* msg); + static void set_has_graph(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static const ::mindspore::irpb::TypeProto& type_val(const ValueProto* msg); + static void set_has_type_val(HasBits* has_bits) { + (*has_bits)[0] |= 8u; + } +}; + +const ::mindspore::irpb::TensorProto& +ValueProto::_Internal::tensor_val(const ValueProto* msg) { + return *msg->tensor_val_; +} +const ::mindspore::irpb::GraphProto& +ValueProto::_Internal::graph(const ValueProto* msg) { + return *msg->graph_; +} +const ::mindspore::irpb::TypeProto& +ValueProto::_Internal::type_val(const ValueProto* msg) { + return *msg->type_val_; +} +ValueProto::ValueProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + bool_vals_(arena), + int_vals_(arena), + uint_vals_(arena), + float_vals_(arena), + double_vals_(arena), + str_vals_(arena), + tensor_vals_(arena), + graphs_(arena), + values_(arena), + dict_val_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.ValueProto) +} +ValueProto::ValueProto(const ValueProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + bool_vals_(from.bool_vals_), + int_vals_(from.int_vals_), + uint_vals_(from.uint_vals_), + float_vals_(from.float_vals_), + double_vals_(from.double_vals_), + str_vals_(from.str_vals_), + tensor_vals_(from.tensor_vals_), + graphs_(from.graphs_), + values_(from.values_), + dict_val_(from.dict_val_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + str_val_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_str_val()) { + str_val_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_str_val(), + GetArena()); + } + if (from._internal_has_tensor_val()) { + tensor_val_ = new ::mindspore::irpb::TensorProto(*from.tensor_val_); + } else { + tensor_val_ = nullptr; + } + if (from._internal_has_graph()) { + graph_ = new ::mindspore::irpb::GraphProto(*from.graph_); + } else { + graph_ = nullptr; + } + if (from._internal_has_type_val()) { + type_val_ = new ::mindspore::irpb::TypeProto(*from.type_val_); + } else { + type_val_ = nullptr; + } + ::memcpy(&dtype_, &from.dtype_, + static_cast(reinterpret_cast(&float_val_) - + reinterpret_cast(&dtype_)) + sizeof(float_val_)); + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.ValueProto) +} + +void ValueProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + str_val_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&tensor_val_, 0, static_cast( + reinterpret_cast(&float_val_) - + reinterpret_cast(&tensor_val_)) + sizeof(float_val_)); +} + +ValueProto::~ValueProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.ValueProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void ValueProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + str_val_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete tensor_val_; + if (this != internal_default_instance()) delete graph_; + if (this != internal_default_instance()) delete type_val_; +} + +void ValueProto::ArenaDtor(void* object) { + ValueProto* _this = reinterpret_cast< ValueProto* >(object); + (void)_this; +} +void ValueProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void ValueProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const ValueProto& ValueProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void ValueProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.ValueProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + bool_vals_.Clear(); + int_vals_.Clear(); + uint_vals_.Clear(); + float_vals_.Clear(); + double_vals_.Clear(); + str_vals_.Clear(); + tensor_vals_.Clear(); + graphs_.Clear(); + values_.Clear(); + dict_val_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000000fu) { + if (cached_has_bits & 0x00000001u) { + str_val_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + GOOGLE_DCHECK(tensor_val_ != nullptr); + tensor_val_->Clear(); + } + if (cached_has_bits & 0x00000004u) { + GOOGLE_DCHECK(graph_ != nullptr); + graph_->Clear(); + } + if (cached_has_bits & 0x00000008u) { + GOOGLE_DCHECK(type_val_ != nullptr); + type_val_->Clear(); + } + } + if (cached_has_bits & 0x000000f0u) { + ::memset(&dtype_, 0, static_cast( + reinterpret_cast(&uint_val_) - + reinterpret_cast(&dtype_)) + sizeof(uint_val_)); + } + if (cached_has_bits & 0x00000300u) { + ::memset(&double_val_, 0, static_cast( + reinterpret_cast(&float_val_) - + reinterpret_cast(&double_val_)) + sizeof(float_val_)); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* ValueProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional .mindspore.irpb.DataType dtype = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + if (PROTOBUF_PREDICT_TRUE(::mindspore::irpb::DataType_IsValid(val))) { + _internal_set_dtype(static_cast<::mindspore::irpb::DataType>(val)); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::WriteVarint(1, val, mutable_unknown_fields()); + } + } else goto handle_unusual; + continue; + // optional bool bool_val = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + _Internal::set_has_bool_val(&has_bits); + bool_val_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional int64 int_val = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + _Internal::set_has_int_val(&has_bits); + int_val_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional uint64 uint_val = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 32)) { + _Internal::set_has_uint_val(&has_bits); + uint_val_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional float float_val = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 45)) { + _Internal::set_has_float_val(&has_bits); + float_val_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // optional double double_val = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 49)) { + _Internal::set_has_double_val(&has_bits); + double_val_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // optional string str_val = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + auto str = _internal_mutable_str_val(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.ValueProto.str_val"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.TensorProto tensor_val = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + ptr = ctx->ParseMessage(_internal_mutable_tensor_val(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.GraphProto graph = 9; + case 9: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 74)) { + ptr = ctx->ParseMessage(_internal_mutable_graph(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated bool bool_vals = 10; + case 10: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 80)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_bool_vals(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<80>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 82) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedBoolParser(_internal_mutable_bool_vals(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated int64 int_vals = 11; + case 11: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 88)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_int_vals(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<88>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 90) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt64Parser(_internal_mutable_int_vals(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated uint64 uint_vals = 12; + case 12: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 96)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_uint_vals(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<96>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 98) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedUInt64Parser(_internal_mutable_uint_vals(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float float_vals = 13; + case 13: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 109)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_float_vals(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<109>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 106) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_float_vals(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated double double_vals = 14; + case 14: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 113)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_double_vals(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(double); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<113>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 114) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedDoubleParser(_internal_mutable_double_vals(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated string str_vals = 15; + case 15: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 122)) { + ptr -= 1; + do { + ptr += 1; + auto str = _internal_add_str_vals(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.ValueProto.str_vals"); + #endif // !NDEBUG + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<122>(ptr)); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.TensorProto tensor_vals = 16; + case 16: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 130)) { + ptr -= 2; + do { + ptr += 2; + ptr = ctx->ParseMessage(_internal_add_tensor_vals(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<130>(ptr)); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.GraphProto graphs = 17; + case 17: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 138)) { + ptr -= 2; + do { + ptr += 2; + ptr = ctx->ParseMessage(_internal_add_graphs(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<138>(ptr)); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.ValueProto values = 18; + case 18: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 146)) { + ptr -= 2; + do { + ptr += 2; + ptr = ctx->ParseMessage(_internal_add_values(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<146>(ptr)); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.NamedValueProto dict_val = 19; + case 19: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 154)) { + ptr -= 2; + do { + ptr += 2; + ptr = ctx->ParseMessage(_internal_add_dict_val(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<154>(ptr)); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.TypeProto type_val = 20; + case 20: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 162)) { + ptr = ctx->ParseMessage(_internal_mutable_type_val(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* ValueProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.ValueProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional .mindspore.irpb.DataType dtype = 1; + if (cached_has_bits & 0x00000010u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_dtype(), target); + } + + // optional bool bool_val = 2; + if (cached_has_bits & 0x00000020u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(2, this->_internal_bool_val(), target); + } + + // optional int64 int_val = 3; + if (cached_has_bits & 0x00000040u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(3, this->_internal_int_val(), target); + } + + // optional uint64 uint_val = 4; + if (cached_has_bits & 0x00000080u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteUInt64ToArray(4, this->_internal_uint_val(), target); + } + + // optional float float_val = 5; + if (cached_has_bits & 0x00000200u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(5, this->_internal_float_val(), target); + } + + // optional double double_val = 6; + if (cached_has_bits & 0x00000100u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(6, this->_internal_double_val(), target); + } + + // optional string str_val = 7; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_str_val().data(), static_cast(this->_internal_str_val().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.ValueProto.str_val"); + target = stream->WriteStringMaybeAliased( + 7, this->_internal_str_val(), target); + } + + // optional .mindspore.irpb.TensorProto tensor_val = 8; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 8, _Internal::tensor_val(this), target, stream); + } + + // optional .mindspore.irpb.GraphProto graph = 9; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 9, _Internal::graph(this), target, stream); + } + + // repeated bool bool_vals = 10; + for (int i = 0, n = this->_internal_bool_vals_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(10, this->_internal_bool_vals(i), target); + } + + // repeated int64 int_vals = 11; + for (int i = 0, n = this->_internal_int_vals_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(11, this->_internal_int_vals(i), target); + } + + // repeated uint64 uint_vals = 12; + for (int i = 0, n = this->_internal_uint_vals_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteUInt64ToArray(12, this->_internal_uint_vals(i), target); + } + + // repeated float float_vals = 13; + for (int i = 0, n = this->_internal_float_vals_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(13, this->_internal_float_vals(i), target); + } + + // repeated double double_vals = 14; + for (int i = 0, n = this->_internal_double_vals_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(14, this->_internal_double_vals(i), target); + } + + // repeated string str_vals = 15; + for (int i = 0, n = this->_internal_str_vals_size(); i < n; i++) { + const auto& s = this->_internal_str_vals(i); + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + s.data(), static_cast(s.length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.ValueProto.str_vals"); + target = stream->WriteString(15, s, target); + } + + // repeated .mindspore.irpb.TensorProto tensor_vals = 16; + for (unsigned int i = 0, + n = static_cast(this->_internal_tensor_vals_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(16, this->_internal_tensor_vals(i), target, stream); + } + + // repeated .mindspore.irpb.GraphProto graphs = 17; + for (unsigned int i = 0, + n = static_cast(this->_internal_graphs_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(17, this->_internal_graphs(i), target, stream); + } + + // repeated .mindspore.irpb.ValueProto values = 18; + for (unsigned int i = 0, + n = static_cast(this->_internal_values_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(18, this->_internal_values(i), target, stream); + } + + // repeated .mindspore.irpb.NamedValueProto dict_val = 19; + for (unsigned int i = 0, + n = static_cast(this->_internal_dict_val_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(19, this->_internal_dict_val(i), target, stream); + } + + // optional .mindspore.irpb.TypeProto type_val = 20; + if (cached_has_bits & 0x00000008u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 20, _Internal::type_val(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.ValueProto) + return target; +} + +size_t ValueProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.ValueProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated bool bool_vals = 10; + { + unsigned int count = static_cast(this->_internal_bool_vals_size()); + size_t data_size = 1UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_bool_vals_size()); + total_size += data_size; + } + + // repeated int64 int_vals = 11; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int64Size(this->int_vals_); + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_int_vals_size()); + total_size += data_size; + } + + // repeated uint64 uint_vals = 12; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + UInt64Size(this->uint_vals_); + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_uint_vals_size()); + total_size += data_size; + } + + // repeated float float_vals = 13; + { + unsigned int count = static_cast(this->_internal_float_vals_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_float_vals_size()); + total_size += data_size; + } + + // repeated double double_vals = 14; + { + unsigned int count = static_cast(this->_internal_double_vals_size()); + size_t data_size = 8UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_double_vals_size()); + total_size += data_size; + } + + // repeated string str_vals = 15; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(str_vals_.size()); + for (int i = 0, n = str_vals_.size(); i < n; i++) { + total_size += ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + str_vals_.Get(i)); + } + + // repeated .mindspore.irpb.TensorProto tensor_vals = 16; + total_size += 2UL * this->_internal_tensor_vals_size(); + for (const auto& msg : this->tensor_vals_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .mindspore.irpb.GraphProto graphs = 17; + total_size += 2UL * this->_internal_graphs_size(); + for (const auto& msg : this->graphs_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .mindspore.irpb.ValueProto values = 18; + total_size += 2UL * this->_internal_values_size(); + for (const auto& msg : this->values_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .mindspore.irpb.NamedValueProto dict_val = 19; + total_size += 2UL * this->_internal_dict_val_size(); + for (const auto& msg : this->dict_val_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x000000ffu) { + // optional string str_val = 7; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_str_val()); + } + + // optional .mindspore.irpb.TensorProto tensor_val = 8; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *tensor_val_); + } + + // optional .mindspore.irpb.GraphProto graph = 9; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *graph_); + } + + // optional .mindspore.irpb.TypeProto type_val = 20; + if (cached_has_bits & 0x00000008u) { + total_size += 2 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *type_val_); + } + + // optional .mindspore.irpb.DataType dtype = 1; + if (cached_has_bits & 0x00000010u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_dtype()); + } + + // optional bool bool_val = 2; + if (cached_has_bits & 0x00000020u) { + total_size += 1 + 1; + } + + // optional int64 int_val = 3; + if (cached_has_bits & 0x00000040u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_int_val()); + } + + // optional uint64 uint_val = 4; + if (cached_has_bits & 0x00000080u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::UInt64Size( + this->_internal_uint_val()); + } + + } + if (cached_has_bits & 0x00000300u) { + // optional double double_val = 6; + if (cached_has_bits & 0x00000100u) { + total_size += 1 + 8; + } + + // optional float float_val = 5; + if (cached_has_bits & 0x00000200u) { + total_size += 1 + 4; + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void ValueProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.ValueProto) + GOOGLE_DCHECK_NE(&from, this); + const ValueProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.ValueProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.ValueProto) + MergeFrom(*source); + } +} + +void ValueProto::MergeFrom(const ValueProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.ValueProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + bool_vals_.MergeFrom(from.bool_vals_); + int_vals_.MergeFrom(from.int_vals_); + uint_vals_.MergeFrom(from.uint_vals_); + float_vals_.MergeFrom(from.float_vals_); + double_vals_.MergeFrom(from.double_vals_); + str_vals_.MergeFrom(from.str_vals_); + tensor_vals_.MergeFrom(from.tensor_vals_); + graphs_.MergeFrom(from.graphs_); + values_.MergeFrom(from.values_); + dict_val_.MergeFrom(from.dict_val_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x000000ffu) { + if (cached_has_bits & 0x00000001u) { + _internal_set_str_val(from._internal_str_val()); + } + if (cached_has_bits & 0x00000002u) { + _internal_mutable_tensor_val()->::mindspore::irpb::TensorProto::MergeFrom(from._internal_tensor_val()); + } + if (cached_has_bits & 0x00000004u) { + _internal_mutable_graph()->::mindspore::irpb::GraphProto::MergeFrom(from._internal_graph()); + } + if (cached_has_bits & 0x00000008u) { + _internal_mutable_type_val()->::mindspore::irpb::TypeProto::MergeFrom(from._internal_type_val()); + } + if (cached_has_bits & 0x00000010u) { + dtype_ = from.dtype_; + } + if (cached_has_bits & 0x00000020u) { + bool_val_ = from.bool_val_; + } + if (cached_has_bits & 0x00000040u) { + int_val_ = from.int_val_; + } + if (cached_has_bits & 0x00000080u) { + uint_val_ = from.uint_val_; + } + _has_bits_[0] |= cached_has_bits; + } + if (cached_has_bits & 0x00000300u) { + if (cached_has_bits & 0x00000100u) { + double_val_ = from.double_val_; + } + if (cached_has_bits & 0x00000200u) { + float_val_ = from.float_val_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void ValueProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.ValueProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ValueProto::CopyFrom(const ValueProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.ValueProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ValueProto::IsInitialized() const { + return true; +} + +void ValueProto::InternalSwap(ValueProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + bool_vals_.InternalSwap(&other->bool_vals_); + int_vals_.InternalSwap(&other->int_vals_); + uint_vals_.InternalSwap(&other->uint_vals_); + float_vals_.InternalSwap(&other->float_vals_); + double_vals_.InternalSwap(&other->double_vals_); + str_vals_.InternalSwap(&other->str_vals_); + tensor_vals_.InternalSwap(&other->tensor_vals_); + graphs_.InternalSwap(&other->graphs_); + values_.InternalSwap(&other->values_); + dict_val_.InternalSwap(&other->dict_val_); + str_val_.Swap(&other->str_val_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(ValueProto, float_val_) + + sizeof(ValueProto::float_val_) + - PROTOBUF_FIELD_OFFSET(ValueProto, tensor_val_)>( + reinterpret_cast(&tensor_val_), + reinterpret_cast(&other->tensor_val_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata ValueProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void AttributeProto::InitAsDefaultInstance() { + ::mindspore::irpb::_AttributeProto_default_instance_._instance.get_mutable()->value_ = const_cast< ::mindspore::irpb::ValueProto*>( + ::mindspore::irpb::ValueProto::internal_default_instance()); +} +class AttributeProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::mindspore::irpb::ValueProto& value(const AttributeProto* msg); + static void set_has_value(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } +}; + +const ::mindspore::irpb::ValueProto& +AttributeProto::_Internal::value(const AttributeProto* msg) { + return *msg->value_; +} +AttributeProto::AttributeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.AttributeProto) +} +AttributeProto::AttributeProto(const AttributeProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + if (from._internal_has_value()) { + value_ = new ::mindspore::irpb::ValueProto(*from.value_); + } else { + value_ = nullptr; + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.AttributeProto) +} + +void AttributeProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + value_ = nullptr; +} + +AttributeProto::~AttributeProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.AttributeProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void AttributeProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete value_; +} + +void AttributeProto::ArenaDtor(void* object) { + AttributeProto* _this = reinterpret_cast< AttributeProto* >(object); + (void)_this; +} +void AttributeProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void AttributeProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const AttributeProto& AttributeProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void AttributeProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.AttributeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + GOOGLE_DCHECK(value_ != nullptr); + value_->Clear(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* AttributeProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string name = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.AttributeProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.ValueProto value = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_value(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* AttributeProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.AttributeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.AttributeProto.name"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_name(), target); + } + + // optional .mindspore.irpb.ValueProto value = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::value(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.AttributeProto) + return target; +} + +size_t AttributeProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.AttributeProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional .mindspore.irpb.ValueProto value = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void AttributeProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.AttributeProto) + GOOGLE_DCHECK_NE(&from, this); + const AttributeProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.AttributeProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.AttributeProto) + MergeFrom(*source); + } +} + +void AttributeProto::MergeFrom(const AttributeProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.AttributeProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_mutable_value()->::mindspore::irpb::ValueProto::MergeFrom(from._internal_value()); + } + } +} + +void AttributeProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.AttributeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void AttributeProto::CopyFrom(const AttributeProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.AttributeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool AttributeProto::IsInitialized() const { + return true; +} + +void AttributeProto::InternalSwap(AttributeProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(value_, other->value_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata AttributeProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void NamedValueProto::InitAsDefaultInstance() { + ::mindspore::irpb::_NamedValueProto_default_instance_._instance.get_mutable()->value_ = const_cast< ::mindspore::irpb::ValueProto*>( + ::mindspore::irpb::ValueProto::internal_default_instance()); +} +class NamedValueProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_key(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::mindspore::irpb::ValueProto& value(const NamedValueProto* msg); + static void set_has_value(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } +}; + +const ::mindspore::irpb::ValueProto& +NamedValueProto::_Internal::value(const NamedValueProto* msg) { + return *msg->value_; +} +NamedValueProto::NamedValueProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.NamedValueProto) +} +NamedValueProto::NamedValueProto(const NamedValueProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + key_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_key()) { + key_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_key(), + GetArena()); + } + if (from._internal_has_value()) { + value_ = new ::mindspore::irpb::ValueProto(*from.value_); + } else { + value_ = nullptr; + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.NamedValueProto) +} + +void NamedValueProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + key_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + value_ = nullptr; +} + +NamedValueProto::~NamedValueProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.NamedValueProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void NamedValueProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + key_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete value_; +} + +void NamedValueProto::ArenaDtor(void* object) { + NamedValueProto* _this = reinterpret_cast< NamedValueProto* >(object); + (void)_this; +} +void NamedValueProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void NamedValueProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const NamedValueProto& NamedValueProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void NamedValueProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.NamedValueProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + key_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + GOOGLE_DCHECK(value_ != nullptr); + value_->Clear(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* NamedValueProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string key = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_key(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.NamedValueProto.key"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.ValueProto value = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_value(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* NamedValueProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.NamedValueProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string key = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_key().data(), static_cast(this->_internal_key().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.NamedValueProto.key"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_key(), target); + } + + // optional .mindspore.irpb.ValueProto value = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::value(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.NamedValueProto) + return target; +} + +size_t NamedValueProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.NamedValueProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional string key = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_key()); + } + + // optional .mindspore.irpb.ValueProto value = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void NamedValueProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.NamedValueProto) + GOOGLE_DCHECK_NE(&from, this); + const NamedValueProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.NamedValueProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.NamedValueProto) + MergeFrom(*source); + } +} + +void NamedValueProto::MergeFrom(const NamedValueProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.NamedValueProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_key(from._internal_key()); + } + if (cached_has_bits & 0x00000002u) { + _internal_mutable_value()->::mindspore::irpb::ValueProto::MergeFrom(from._internal_value()); + } + } +} + +void NamedValueProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.NamedValueProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void NamedValueProto::CopyFrom(const NamedValueProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.NamedValueProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool NamedValueProto::IsInitialized() const { + return true; +} + +void NamedValueProto::InternalSwap(NamedValueProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + key_.Swap(&other->key_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(value_, other->value_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata NamedValueProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TensorShapeProto_Dimension::InitAsDefaultInstance() { +} +class TensorShapeProto_Dimension::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_size(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +TensorShapeProto_Dimension::TensorShapeProto_Dimension(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.TensorShapeProto.Dimension) +} +TensorShapeProto_Dimension::TensorShapeProto_Dimension(const TensorShapeProto_Dimension& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + size_ = from.size_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.TensorShapeProto.Dimension) +} + +void TensorShapeProto_Dimension::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TensorShapeProto_Dimension_mindspore_5fanf_5fir_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + size_ = PROTOBUF_LONGLONG(0); +} + +TensorShapeProto_Dimension::~TensorShapeProto_Dimension() { + // @@protoc_insertion_point(destructor:mindspore.irpb.TensorShapeProto.Dimension) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TensorShapeProto_Dimension::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void TensorShapeProto_Dimension::ArenaDtor(void* object) { + TensorShapeProto_Dimension* _this = reinterpret_cast< TensorShapeProto_Dimension* >(object); + (void)_this; +} +void TensorShapeProto_Dimension::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TensorShapeProto_Dimension::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TensorShapeProto_Dimension& TensorShapeProto_Dimension::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorShapeProto_Dimension_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void TensorShapeProto_Dimension::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.TensorShapeProto.Dimension) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + size_ = PROTOBUF_LONGLONG(0); + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TensorShapeProto_Dimension::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional int64 size = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + _Internal::set_has_size(&has_bits); + size_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string name = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.TensorShapeProto.Dimension.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TensorShapeProto_Dimension::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.TensorShapeProto.Dimension) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional int64 size = 1; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(1, this->_internal_size(), target); + } + + // optional string name = 2; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.TensorShapeProto.Dimension.name"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_name(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.TensorShapeProto.Dimension) + return target; +} + +size_t TensorShapeProto_Dimension::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.TensorShapeProto.Dimension) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional string name = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional int64 size = 1; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_size()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TensorShapeProto_Dimension::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.TensorShapeProto.Dimension) + GOOGLE_DCHECK_NE(&from, this); + const TensorShapeProto_Dimension* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.TensorShapeProto.Dimension) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.TensorShapeProto.Dimension) + MergeFrom(*source); + } +} + +void TensorShapeProto_Dimension::MergeFrom(const TensorShapeProto_Dimension& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.TensorShapeProto.Dimension) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + size_ = from.size_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void TensorShapeProto_Dimension::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.TensorShapeProto.Dimension) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TensorShapeProto_Dimension::CopyFrom(const TensorShapeProto_Dimension& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.TensorShapeProto.Dimension) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TensorShapeProto_Dimension::IsInitialized() const { + return true; +} + +void TensorShapeProto_Dimension::InternalSwap(TensorShapeProto_Dimension* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(size_, other->size_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TensorShapeProto_Dimension::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TensorShapeProto::InitAsDefaultInstance() { +} +class TensorShapeProto::_Internal { + public: +}; + +TensorShapeProto::TensorShapeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + dim_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.TensorShapeProto) +} +TensorShapeProto::TensorShapeProto(const TensorShapeProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + dim_(from.dim_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.TensorShapeProto) +} + +void TensorShapeProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TensorShapeProto_mindspore_5fanf_5fir_2eproto.base); +} + +TensorShapeProto::~TensorShapeProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.TensorShapeProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TensorShapeProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void TensorShapeProto::ArenaDtor(void* object) { + TensorShapeProto* _this = reinterpret_cast< TensorShapeProto* >(object); + (void)_this; +} +void TensorShapeProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TensorShapeProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TensorShapeProto& TensorShapeProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorShapeProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void TensorShapeProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.TensorShapeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + dim_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TensorShapeProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .mindspore.irpb.TensorShapeProto.Dimension dim = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_dim(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TensorShapeProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.TensorShapeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .mindspore.irpb.TensorShapeProto.Dimension dim = 1; + for (unsigned int i = 0, + n = static_cast(this->_internal_dim_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(1, this->_internal_dim(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.TensorShapeProto) + return target; +} + +size_t TensorShapeProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.TensorShapeProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .mindspore.irpb.TensorShapeProto.Dimension dim = 1; + total_size += 1UL * this->_internal_dim_size(); + for (const auto& msg : this->dim_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TensorShapeProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.TensorShapeProto) + GOOGLE_DCHECK_NE(&from, this); + const TensorShapeProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.TensorShapeProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.TensorShapeProto) + MergeFrom(*source); + } +} + +void TensorShapeProto::MergeFrom(const TensorShapeProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.TensorShapeProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + dim_.MergeFrom(from.dim_); +} + +void TensorShapeProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.TensorShapeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TensorShapeProto::CopyFrom(const TensorShapeProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.TensorShapeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TensorShapeProto::IsInitialized() const { + return true; +} + +void TensorShapeProto::InternalSwap(TensorShapeProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + dim_.InternalSwap(&other->dim_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TensorShapeProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TypeProto_Tensor::InitAsDefaultInstance() { + ::mindspore::irpb::_TypeProto_Tensor_default_instance_._instance.get_mutable()->shape_ = const_cast< ::mindspore::irpb::TensorShapeProto*>( + ::mindspore::irpb::TensorShapeProto::internal_default_instance()); +} +class TypeProto_Tensor::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_elem_type(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static const ::mindspore::irpb::TensorShapeProto& shape(const TypeProto_Tensor* msg); + static void set_has_shape(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +const ::mindspore::irpb::TensorShapeProto& +TypeProto_Tensor::_Internal::shape(const TypeProto_Tensor* msg) { + return *msg->shape_; +} +TypeProto_Tensor::TypeProto_Tensor(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.TypeProto.Tensor) +} +TypeProto_Tensor::TypeProto_Tensor(const TypeProto_Tensor& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + if (from._internal_has_shape()) { + shape_ = new ::mindspore::irpb::TensorShapeProto(*from.shape_); + } else { + shape_ = nullptr; + } + elem_type_ = from.elem_type_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.TypeProto.Tensor) +} + +void TypeProto_Tensor::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TypeProto_Tensor_mindspore_5fanf_5fir_2eproto.base); + ::memset(&shape_, 0, static_cast( + reinterpret_cast(&elem_type_) - + reinterpret_cast(&shape_)) + sizeof(elem_type_)); +} + +TypeProto_Tensor::~TypeProto_Tensor() { + // @@protoc_insertion_point(destructor:mindspore.irpb.TypeProto.Tensor) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TypeProto_Tensor::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (this != internal_default_instance()) delete shape_; +} + +void TypeProto_Tensor::ArenaDtor(void* object) { + TypeProto_Tensor* _this = reinterpret_cast< TypeProto_Tensor* >(object); + (void)_this; +} +void TypeProto_Tensor::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TypeProto_Tensor::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TypeProto_Tensor& TypeProto_Tensor::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TypeProto_Tensor_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void TypeProto_Tensor::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.TypeProto.Tensor) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + GOOGLE_DCHECK(shape_ != nullptr); + shape_->Clear(); + } + elem_type_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TypeProto_Tensor::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional .mindspore.irpb.DataType elem_type = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + if (PROTOBUF_PREDICT_TRUE(::mindspore::irpb::DataType_IsValid(val))) { + _internal_set_elem_type(static_cast<::mindspore::irpb::DataType>(val)); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::WriteVarint(1, val, mutable_unknown_fields()); + } + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.TensorShapeProto shape = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_shape(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TypeProto_Tensor::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.TypeProto.Tensor) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional .mindspore.irpb.DataType elem_type = 1; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_elem_type(), target); + } + + // optional .mindspore.irpb.TensorShapeProto shape = 2; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::shape(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.TypeProto.Tensor) + return target; +} + +size_t TypeProto_Tensor::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.TypeProto.Tensor) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional .mindspore.irpb.TensorShapeProto shape = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *shape_); + } + + // optional .mindspore.irpb.DataType elem_type = 1; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_elem_type()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TypeProto_Tensor::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.TypeProto.Tensor) + GOOGLE_DCHECK_NE(&from, this); + const TypeProto_Tensor* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.TypeProto.Tensor) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.TypeProto.Tensor) + MergeFrom(*source); + } +} + +void TypeProto_Tensor::MergeFrom(const TypeProto_Tensor& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.TypeProto.Tensor) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_mutable_shape()->::mindspore::irpb::TensorShapeProto::MergeFrom(from._internal_shape()); + } + if (cached_has_bits & 0x00000002u) { + elem_type_ = from.elem_type_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void TypeProto_Tensor::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.TypeProto.Tensor) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TypeProto_Tensor::CopyFrom(const TypeProto_Tensor& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.TypeProto.Tensor) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TypeProto_Tensor::IsInitialized() const { + return true; +} + +void TypeProto_Tensor::InternalSwap(TypeProto_Tensor* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(TypeProto_Tensor, elem_type_) + + sizeof(TypeProto_Tensor::elem_type_) + - PROTOBUF_FIELD_OFFSET(TypeProto_Tensor, shape_)>( + reinterpret_cast(&shape_), + reinterpret_cast(&other->shape_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TypeProto_Tensor::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TypeProto_Sequence::InitAsDefaultInstance() { +} +class TypeProto_Sequence::_Internal { + public: +}; + +TypeProto_Sequence::TypeProto_Sequence(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + elem_types_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.TypeProto.Sequence) +} +TypeProto_Sequence::TypeProto_Sequence(const TypeProto_Sequence& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + elem_types_(from.elem_types_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.TypeProto.Sequence) +} + +void TypeProto_Sequence::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TypeProto_mindspore_5fanf_5fir_2eproto.base); +} + +TypeProto_Sequence::~TypeProto_Sequence() { + // @@protoc_insertion_point(destructor:mindspore.irpb.TypeProto.Sequence) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TypeProto_Sequence::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void TypeProto_Sequence::ArenaDtor(void* object) { + TypeProto_Sequence* _this = reinterpret_cast< TypeProto_Sequence* >(object); + (void)_this; +} +void TypeProto_Sequence::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TypeProto_Sequence::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TypeProto_Sequence& TypeProto_Sequence::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TypeProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void TypeProto_Sequence::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.TypeProto.Sequence) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + elem_types_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TypeProto_Sequence::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .mindspore.irpb.TypeProto elem_types = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_elem_types(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TypeProto_Sequence::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.TypeProto.Sequence) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .mindspore.irpb.TypeProto elem_types = 1; + for (unsigned int i = 0, + n = static_cast(this->_internal_elem_types_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(1, this->_internal_elem_types(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.TypeProto.Sequence) + return target; +} + +size_t TypeProto_Sequence::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.TypeProto.Sequence) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .mindspore.irpb.TypeProto elem_types = 1; + total_size += 1UL * this->_internal_elem_types_size(); + for (const auto& msg : this->elem_types_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TypeProto_Sequence::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.TypeProto.Sequence) + GOOGLE_DCHECK_NE(&from, this); + const TypeProto_Sequence* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.TypeProto.Sequence) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.TypeProto.Sequence) + MergeFrom(*source); + } +} + +void TypeProto_Sequence::MergeFrom(const TypeProto_Sequence& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.TypeProto.Sequence) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + elem_types_.MergeFrom(from.elem_types_); +} + +void TypeProto_Sequence::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.TypeProto.Sequence) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TypeProto_Sequence::CopyFrom(const TypeProto_Sequence& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.TypeProto.Sequence) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TypeProto_Sequence::IsInitialized() const { + return true; +} + +void TypeProto_Sequence::InternalSwap(TypeProto_Sequence* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + elem_types_.InternalSwap(&other->elem_types_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TypeProto_Sequence::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TypeProto::InitAsDefaultInstance() { + ::mindspore::irpb::_TypeProto_default_instance_.tensor_type_ = const_cast< ::mindspore::irpb::TypeProto_Tensor*>( + ::mindspore::irpb::TypeProto_Tensor::internal_default_instance()); + ::mindspore::irpb::_TypeProto_default_instance_.sequence_type_ = const_cast< ::mindspore::irpb::TypeProto_Sequence*>( + ::mindspore::irpb::TypeProto_Sequence::internal_default_instance()); +} +class TypeProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_data_type(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::mindspore::irpb::TypeProto_Tensor& tensor_type(const TypeProto* msg); + static const ::mindspore::irpb::TypeProto_Sequence& sequence_type(const TypeProto* msg); +}; + +const ::mindspore::irpb::TypeProto_Tensor& +TypeProto::_Internal::tensor_type(const TypeProto* msg) { + return *msg->value_.tensor_type_; +} +const ::mindspore::irpb::TypeProto_Sequence& +TypeProto::_Internal::sequence_type(const TypeProto* msg) { + return *msg->value_.sequence_type_; +} +void TypeProto::set_allocated_tensor_type(::mindspore::irpb::TypeProto_Tensor* tensor_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (tensor_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(tensor_type); + if (message_arena != submessage_arena) { + tensor_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, tensor_type, submessage_arena); + } + set_has_tensor_type(); + value_.tensor_type_ = tensor_type; + } + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.TypeProto.tensor_type) +} +void TypeProto::set_allocated_sequence_type(::mindspore::irpb::TypeProto_Sequence* sequence_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (sequence_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(sequence_type); + if (message_arena != submessage_arena) { + sequence_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, sequence_type, submessage_arena); + } + set_has_sequence_type(); + value_.sequence_type_ = sequence_type; + } + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.TypeProto.sequence_type) +} +TypeProto::TypeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.TypeProto) +} +TypeProto::TypeProto(const TypeProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + data_type_ = from.data_type_; + clear_has_value(); + switch (from.value_case()) { + case kTensorType: { + _internal_mutable_tensor_type()->::mindspore::irpb::TypeProto_Tensor::MergeFrom(from._internal_tensor_type()); + break; + } + case kSequenceType: { + _internal_mutable_sequence_type()->::mindspore::irpb::TypeProto_Sequence::MergeFrom(from._internal_sequence_type()); + break; + } + case VALUE_NOT_SET: { + break; + } + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.TypeProto) +} + +void TypeProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TypeProto_mindspore_5fanf_5fir_2eproto.base); + data_type_ = 0; + clear_has_value(); +} + +TypeProto::~TypeProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.TypeProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TypeProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (has_value()) { + clear_value(); + } +} + +void TypeProto::ArenaDtor(void* object) { + TypeProto* _this = reinterpret_cast< TypeProto* >(object); + (void)_this; +} +void TypeProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TypeProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TypeProto& TypeProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TypeProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void TypeProto::clear_value() { +// @@protoc_insertion_point(one_of_clear_start:mindspore.irpb.TypeProto) + switch (value_case()) { + case kTensorType: { + if (GetArena() == nullptr) { + delete value_.tensor_type_; + } + break; + } + case kSequenceType: { + if (GetArena() == nullptr) { + delete value_.sequence_type_; + } + break; + } + case VALUE_NOT_SET: { + break; + } + } + _oneof_case_[0] = VALUE_NOT_SET; +} + + +void TypeProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.TypeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + data_type_ = 0; + clear_value(); + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TypeProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional .mindspore.irpb.DataType data_type = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + if (PROTOBUF_PREDICT_TRUE(::mindspore::irpb::DataType_IsValid(val))) { + _internal_set_data_type(static_cast<::mindspore::irpb::DataType>(val)); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::WriteVarint(1, val, mutable_unknown_fields()); + } + } else goto handle_unusual; + continue; + // .mindspore.irpb.TypeProto.Tensor tensor_type = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_tensor_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .mindspore.irpb.TypeProto.Sequence sequence_type = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr = ctx->ParseMessage(_internal_mutable_sequence_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TypeProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.TypeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional .mindspore.irpb.DataType data_type = 1; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_data_type(), target); + } + + switch (value_case()) { + case kTensorType: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::tensor_type(this), target, stream); + break; + } + case kSequenceType: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 3, _Internal::sequence_type(this), target, stream); + break; + } + default: ; + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.TypeProto) + return target; +} + +size_t TypeProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.TypeProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // optional .mindspore.irpb.DataType data_type = 1; + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_data_type()); + } + + switch (value_case()) { + // .mindspore.irpb.TypeProto.Tensor tensor_type = 2; + case kTensorType: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.tensor_type_); + break; + } + // .mindspore.irpb.TypeProto.Sequence sequence_type = 3; + case kSequenceType: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.sequence_type_); + break; + } + case VALUE_NOT_SET: { + break; + } + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TypeProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.TypeProto) + GOOGLE_DCHECK_NE(&from, this); + const TypeProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.TypeProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.TypeProto) + MergeFrom(*source); + } +} + +void TypeProto::MergeFrom(const TypeProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.TypeProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from._internal_has_data_type()) { + _internal_set_data_type(from._internal_data_type()); + } + switch (from.value_case()) { + case kTensorType: { + _internal_mutable_tensor_type()->::mindspore::irpb::TypeProto_Tensor::MergeFrom(from._internal_tensor_type()); + break; + } + case kSequenceType: { + _internal_mutable_sequence_type()->::mindspore::irpb::TypeProto_Sequence::MergeFrom(from._internal_sequence_type()); + break; + } + case VALUE_NOT_SET: { + break; + } + } +} + +void TypeProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.TypeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TypeProto::CopyFrom(const TypeProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.TypeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TypeProto::IsInitialized() const { + return true; +} + +void TypeProto::InternalSwap(TypeProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + swap(data_type_, other->data_type_); + swap(value_, other->value_); + swap(_oneof_case_[0], other->_oneof_case_[0]); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TypeProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void ParameterProto::InitAsDefaultInstance() { + ::mindspore::irpb::_ParameterProto_default_instance_._instance.get_mutable()->type_ = const_cast< ::mindspore::irpb::TypeProto*>( + ::mindspore::irpb::TypeProto::internal_default_instance()); + ::mindspore::irpb::_ParameterProto_default_instance_._instance.get_mutable()->default_val_ = const_cast< ::mindspore::irpb::ValueProto*>( + ::mindspore::irpb::ValueProto::internal_default_instance()); +} +class ParameterProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::mindspore::irpb::TypeProto& type(const ParameterProto* msg); + static void set_has_type(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static const ::mindspore::irpb::ValueProto& default_val(const ParameterProto* msg); + static void set_has_default_val(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } +}; + +const ::mindspore::irpb::TypeProto& +ParameterProto::_Internal::type(const ParameterProto* msg) { + return *msg->type_; +} +const ::mindspore::irpb::ValueProto& +ParameterProto::_Internal::default_val(const ParameterProto* msg) { + return *msg->default_val_; +} +ParameterProto::ParameterProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.ParameterProto) +} +ParameterProto::ParameterProto(const ParameterProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + if (from._internal_has_type()) { + type_ = new ::mindspore::irpb::TypeProto(*from.type_); + } else { + type_ = nullptr; + } + if (from._internal_has_default_val()) { + default_val_ = new ::mindspore::irpb::ValueProto(*from.default_val_); + } else { + default_val_ = nullptr; + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.ParameterProto) +} + +void ParameterProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&type_, 0, static_cast( + reinterpret_cast(&default_val_) - + reinterpret_cast(&type_)) + sizeof(default_val_)); +} + +ParameterProto::~ParameterProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.ParameterProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void ParameterProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete type_; + if (this != internal_default_instance()) delete default_val_; +} + +void ParameterProto::ArenaDtor(void* object) { + ParameterProto* _this = reinterpret_cast< ParameterProto* >(object); + (void)_this; +} +void ParameterProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void ParameterProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const ParameterProto& ParameterProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void ParameterProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.ParameterProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + GOOGLE_DCHECK(type_ != nullptr); + type_->Clear(); + } + if (cached_has_bits & 0x00000004u) { + GOOGLE_DCHECK(default_val_ != nullptr); + default_val_->Clear(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* ParameterProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string name = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.ParameterProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.TypeProto type = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.ValueProto default_val = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr = ctx->ParseMessage(_internal_mutable_default_val(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* ParameterProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.ParameterProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.ParameterProto.name"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_name(), target); + } + + // optional .mindspore.irpb.TypeProto type = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::type(this), target, stream); + } + + // optional .mindspore.irpb.ValueProto default_val = 3; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 3, _Internal::default_val(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.ParameterProto) + return target; +} + +size_t ParameterProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.ParameterProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional .mindspore.irpb.TypeProto type = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *type_); + } + + // optional .mindspore.irpb.ValueProto default_val = 3; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *default_val_); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void ParameterProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.ParameterProto) + GOOGLE_DCHECK_NE(&from, this); + const ParameterProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.ParameterProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.ParameterProto) + MergeFrom(*source); + } +} + +void ParameterProto::MergeFrom(const ParameterProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.ParameterProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_mutable_type()->::mindspore::irpb::TypeProto::MergeFrom(from._internal_type()); + } + if (cached_has_bits & 0x00000004u) { + _internal_mutable_default_val()->::mindspore::irpb::ValueProto::MergeFrom(from._internal_default_val()); + } + } +} + +void ParameterProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.ParameterProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ParameterProto::CopyFrom(const ParameterProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.ParameterProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ParameterProto::IsInitialized() const { + return true; +} + +void ParameterProto::InternalSwap(ParameterProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(ParameterProto, default_val_) + + sizeof(ParameterProto::default_val_) + - PROTOBUF_FIELD_OFFSET(ParameterProto, type_)>( + reinterpret_cast(&type_), + reinterpret_cast(&other->type_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata ParameterProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void OutputProto::InitAsDefaultInstance() { + ::mindspore::irpb::_OutputProto_default_instance_._instance.get_mutable()->type_ = const_cast< ::mindspore::irpb::TypeProto*>( + ::mindspore::irpb::TypeProto::internal_default_instance()); +} +class OutputProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::mindspore::irpb::TypeProto& type(const OutputProto* msg); + static void set_has_type(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } +}; + +const ::mindspore::irpb::TypeProto& +OutputProto::_Internal::type(const OutputProto* msg) { + return *msg->type_; +} +OutputProto::OutputProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.OutputProto) +} +OutputProto::OutputProto(const OutputProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + if (from._internal_has_type()) { + type_ = new ::mindspore::irpb::TypeProto(*from.type_); + } else { + type_ = nullptr; + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.OutputProto) +} + +void OutputProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_OutputProto_mindspore_5fanf_5fir_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + type_ = nullptr; +} + +OutputProto::~OutputProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.OutputProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void OutputProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete type_; +} + +void OutputProto::ArenaDtor(void* object) { + OutputProto* _this = reinterpret_cast< OutputProto* >(object); + (void)_this; +} +void OutputProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void OutputProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const OutputProto& OutputProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_OutputProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void OutputProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.OutputProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + GOOGLE_DCHECK(type_ != nullptr); + type_->Clear(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* OutputProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string name = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.OutputProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.TypeProto type = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* OutputProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.OutputProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.OutputProto.name"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_name(), target); + } + + // optional .mindspore.irpb.TypeProto type = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::type(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.OutputProto) + return target; +} + +size_t OutputProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.OutputProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional .mindspore.irpb.TypeProto type = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *type_); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void OutputProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.OutputProto) + GOOGLE_DCHECK_NE(&from, this); + const OutputProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.OutputProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.OutputProto) + MergeFrom(*source); + } +} + +void OutputProto::MergeFrom(const OutputProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.OutputProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_mutable_type()->::mindspore::irpb::TypeProto::MergeFrom(from._internal_type()); + } + } +} + +void OutputProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.OutputProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void OutputProto::CopyFrom(const OutputProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.OutputProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool OutputProto::IsInitialized() const { + return true; +} + +void OutputProto::InternalSwap(OutputProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(type_, other->type_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata OutputProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void InputProto::InitAsDefaultInstance() { +} +class InputProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_type(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } +}; + +InputProto::InputProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.InputProto) +} +InputProto::InputProto(const InputProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + type_ = from.type_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.InputProto) +} + +void InputProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_InputProto_mindspore_5fanf_5fir_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + type_ = 0; +} + +InputProto::~InputProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.InputProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void InputProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void InputProto::ArenaDtor(void* object) { + InputProto* _this = reinterpret_cast< InputProto* >(object); + (void)_this; +} +void InputProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void InputProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const InputProto& InputProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_InputProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void InputProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.InputProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + type_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* InputProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string name = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.InputProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.InputProto.EdgeType type = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + if (PROTOBUF_PREDICT_TRUE(::mindspore::irpb::InputProto_EdgeType_IsValid(val))) { + _internal_set_type(static_cast<::mindspore::irpb::InputProto_EdgeType>(val)); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::WriteVarint(2, val, mutable_unknown_fields()); + } + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* InputProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.InputProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.InputProto.name"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_name(), target); + } + + // optional .mindspore.irpb.InputProto.EdgeType type = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 2, this->_internal_type(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.InputProto) + return target; +} + +size_t InputProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.InputProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional .mindspore.irpb.InputProto.EdgeType type = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_type()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void InputProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.InputProto) + GOOGLE_DCHECK_NE(&from, this); + const InputProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.InputProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.InputProto) + MergeFrom(*source); + } +} + +void InputProto::MergeFrom(const InputProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.InputProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + type_ = from.type_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void InputProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.InputProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void InputProto::CopyFrom(const InputProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.InputProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool InputProto::IsInitialized() const { + return true; +} + +void InputProto::InternalSwap(InputProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(type_, other->type_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata InputProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void NodeProto::InitAsDefaultInstance() { + ::mindspore::irpb::_NodeProto_default_instance_._instance.get_mutable()->output_type_ = const_cast< ::mindspore::irpb::TypeProto*>( + ::mindspore::irpb::TypeProto::internal_default_instance()); +} +class NodeProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_op_type(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_scope(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static const ::mindspore::irpb::TypeProto& output_type(const NodeProto* msg); + static void set_has_output_type(HasBits* has_bits) { + (*has_bits)[0] |= 32u; + } + static void set_has_output_i(HasBits* has_bits) { + (*has_bits)[0] |= 64u; + } + static void set_has_full_name(HasBits* has_bits) { + (*has_bits)[0] |= 8u; + } + static void set_has_instance_name(HasBits* has_bits) { + (*has_bits)[0] |= 16u; + } +}; + +const ::mindspore::irpb::TypeProto& +NodeProto::_Internal::output_type(const NodeProto* msg) { + return *msg->output_type_; +} +NodeProto::NodeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + input_(arena), + attribute_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.NodeProto) +} +NodeProto::NodeProto(const NodeProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + input_(from.input_), + attribute_(from.attribute_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + op_type_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_op_type()) { + op_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_op_type(), + GetArena()); + } + scope_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_scope()) { + scope_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_scope(), + GetArena()); + } + full_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_full_name()) { + full_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_full_name(), + GetArena()); + } + instance_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_instance_name()) { + instance_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_instance_name(), + GetArena()); + } + if (from._internal_has_output_type()) { + output_type_ = new ::mindspore::irpb::TypeProto(*from.output_type_); + } else { + output_type_ = nullptr; + } + output_i_ = from.output_i_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.NodeProto) +} + +void NodeProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + op_type_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + scope_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + full_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + instance_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&output_type_, 0, static_cast( + reinterpret_cast(&output_i_) - + reinterpret_cast(&output_type_)) + sizeof(output_i_)); +} + +NodeProto::~NodeProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.NodeProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void NodeProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + op_type_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + scope_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + full_name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + instance_name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete output_type_; +} + +void NodeProto::ArenaDtor(void* object) { + NodeProto* _this = reinterpret_cast< NodeProto* >(object); + (void)_this; +} +void NodeProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void NodeProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const NodeProto& NodeProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void NodeProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.NodeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + input_.Clear(); + attribute_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000003fu) { + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + op_type_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000004u) { + scope_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000008u) { + full_name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000010u) { + instance_name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000020u) { + GOOGLE_DCHECK(output_type_ != nullptr); + output_type_->Clear(); + } + } + output_i_ = PROTOBUF_ULONGLONG(0); + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* NodeProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .mindspore.irpb.InputProto input = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_input(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + // optional string name = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.NodeProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string op_type = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_op_type(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.NodeProto.op_type"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string scope = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + auto str = _internal_mutable_scope(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.NodeProto.scope"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.AttributeProto attribute = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_attribute(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<42>(ptr)); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.TypeProto output_type = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr = ctx->ParseMessage(_internal_mutable_output_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional uint64 output_i = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 56)) { + _Internal::set_has_output_i(&has_bits); + output_i_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string full_name = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + auto str = _internal_mutable_full_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.NodeProto.full_name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string instance_name = 10; + case 10: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 82)) { + auto str = _internal_mutable_instance_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.NodeProto.instance_name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* NodeProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.NodeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .mindspore.irpb.InputProto input = 1; + for (unsigned int i = 0, + n = static_cast(this->_internal_input_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(1, this->_internal_input(i), target, stream); + } + + cached_has_bits = _has_bits_[0]; + // optional string name = 2; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.NodeProto.name"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_name(), target); + } + + // optional string op_type = 3; + if (cached_has_bits & 0x00000002u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_op_type().data(), static_cast(this->_internal_op_type().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.NodeProto.op_type"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_op_type(), target); + } + + // optional string scope = 4; + if (cached_has_bits & 0x00000004u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_scope().data(), static_cast(this->_internal_scope().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.NodeProto.scope"); + target = stream->WriteStringMaybeAliased( + 4, this->_internal_scope(), target); + } + + // repeated .mindspore.irpb.AttributeProto attribute = 5; + for (unsigned int i = 0, + n = static_cast(this->_internal_attribute_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(5, this->_internal_attribute(i), target, stream); + } + + // optional .mindspore.irpb.TypeProto output_type = 6; + if (cached_has_bits & 0x00000020u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 6, _Internal::output_type(this), target, stream); + } + + // optional uint64 output_i = 7; + if (cached_has_bits & 0x00000040u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteUInt64ToArray(7, this->_internal_output_i(), target); + } + + // optional string full_name = 8; + if (cached_has_bits & 0x00000008u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_full_name().data(), static_cast(this->_internal_full_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.NodeProto.full_name"); + target = stream->WriteStringMaybeAliased( + 8, this->_internal_full_name(), target); + } + + // optional string instance_name = 10; + if (cached_has_bits & 0x00000010u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_instance_name().data(), static_cast(this->_internal_instance_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.NodeProto.instance_name"); + target = stream->WriteStringMaybeAliased( + 10, this->_internal_instance_name(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.NodeProto) + return target; +} + +size_t NodeProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.NodeProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .mindspore.irpb.InputProto input = 1; + total_size += 1UL * this->_internal_input_size(); + for (const auto& msg : this->input_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .mindspore.irpb.AttributeProto attribute = 5; + total_size += 1UL * this->_internal_attribute_size(); + for (const auto& msg : this->attribute_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000007fu) { + // optional string name = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional string op_type = 3; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_op_type()); + } + + // optional string scope = 4; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_scope()); + } + + // optional string full_name = 8; + if (cached_has_bits & 0x00000008u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_full_name()); + } + + // optional string instance_name = 10; + if (cached_has_bits & 0x00000010u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_instance_name()); + } + + // optional .mindspore.irpb.TypeProto output_type = 6; + if (cached_has_bits & 0x00000020u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *output_type_); + } + + // optional uint64 output_i = 7; + if (cached_has_bits & 0x00000040u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::UInt64Size( + this->_internal_output_i()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void NodeProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.NodeProto) + GOOGLE_DCHECK_NE(&from, this); + const NodeProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.NodeProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.NodeProto) + MergeFrom(*source); + } +} + +void NodeProto::MergeFrom(const NodeProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.NodeProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + input_.MergeFrom(from.input_); + attribute_.MergeFrom(from.attribute_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x0000007fu) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_op_type(from._internal_op_type()); + } + if (cached_has_bits & 0x00000004u) { + _internal_set_scope(from._internal_scope()); + } + if (cached_has_bits & 0x00000008u) { + _internal_set_full_name(from._internal_full_name()); + } + if (cached_has_bits & 0x00000010u) { + _internal_set_instance_name(from._internal_instance_name()); + } + if (cached_has_bits & 0x00000020u) { + _internal_mutable_output_type()->::mindspore::irpb::TypeProto::MergeFrom(from._internal_output_type()); + } + if (cached_has_bits & 0x00000040u) { + output_i_ = from.output_i_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void NodeProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.NodeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void NodeProto::CopyFrom(const NodeProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.NodeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool NodeProto::IsInitialized() const { + return true; +} + +void NodeProto::InternalSwap(NodeProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + input_.InternalSwap(&other->input_); + attribute_.InternalSwap(&other->attribute_); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + op_type_.Swap(&other->op_type_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + scope_.Swap(&other->scope_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + full_name_.Swap(&other->full_name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + instance_name_.Swap(&other->instance_name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(NodeProto, output_i_) + + sizeof(NodeProto::output_i_) + - PROTOBUF_FIELD_OFFSET(NodeProto, output_type_)>( + reinterpret_cast(&output_type_), + reinterpret_cast(&other->output_type_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata NodeProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void ModelProto::InitAsDefaultInstance() { + ::mindspore::irpb::_ModelProto_default_instance_._instance.get_mutable()->graph_ = const_cast< ::mindspore::irpb::GraphProto*>( + ::mindspore::irpb::GraphProto::internal_default_instance()); + ::mindspore::irpb::_ModelProto_default_instance_._instance.get_mutable()->metadata_operators_ = const_cast< ::mindspore::irpb::OperatorSetProto*>( + ::mindspore::irpb::OperatorSetProto::internal_default_instance()); +} +class ModelProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_ir_version(HasBits* has_bits) { + (*has_bits)[0] |= 8u; + } + static void set_has_domain(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_model_version(HasBits* has_bits) { + (*has_bits)[0] |= 16u; + } + static const ::mindspore::irpb::GraphProto& graph(const ModelProto* msg); + static void set_has_graph(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static const ::mindspore::irpb::OperatorSetProto& metadata_operators(const ModelProto* msg); + static void set_has_metadata_operators(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } +}; + +const ::mindspore::irpb::GraphProto& +ModelProto::_Internal::graph(const ModelProto* msg) { + return *msg->graph_; +} +const ::mindspore::irpb::OperatorSetProto& +ModelProto::_Internal::metadata_operators(const ModelProto* msg) { + return *msg->metadata_operators_; +} +ModelProto::ModelProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.ModelProto) +} +ModelProto::ModelProto(const ModelProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + domain_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_domain()) { + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_domain(), + GetArena()); + } + if (from._internal_has_graph()) { + graph_ = new ::mindspore::irpb::GraphProto(*from.graph_); + } else { + graph_ = nullptr; + } + if (from._internal_has_metadata_operators()) { + metadata_operators_ = new ::mindspore::irpb::OperatorSetProto(*from.metadata_operators_); + } else { + metadata_operators_ = nullptr; + } + ::memcpy(&ir_version_, &from.ir_version_, + static_cast(reinterpret_cast(&model_version_) - + reinterpret_cast(&ir_version_)) + sizeof(model_version_)); + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.ModelProto) +} + +void ModelProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_ModelProto_mindspore_5fanf_5fir_2eproto.base); + domain_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&graph_, 0, static_cast( + reinterpret_cast(&model_version_) - + reinterpret_cast(&graph_)) + sizeof(model_version_)); +} + +ModelProto::~ModelProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.ModelProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void ModelProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + domain_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete graph_; + if (this != internal_default_instance()) delete metadata_operators_; +} + +void ModelProto::ArenaDtor(void* object) { + ModelProto* _this = reinterpret_cast< ModelProto* >(object); + (void)_this; +} +void ModelProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void ModelProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const ModelProto& ModelProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_ModelProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void ModelProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.ModelProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + domain_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + GOOGLE_DCHECK(graph_ != nullptr); + graph_->Clear(); + } + if (cached_has_bits & 0x00000004u) { + GOOGLE_DCHECK(metadata_operators_ != nullptr); + metadata_operators_->Clear(); + } + } + if (cached_has_bits & 0x00000018u) { + ::memset(&ir_version_, 0, static_cast( + reinterpret_cast(&model_version_) - + reinterpret_cast(&ir_version_)) + sizeof(model_version_)); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* ModelProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional int64 ir_version = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + _Internal::set_has_ir_version(&has_bits); + ir_version_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string domain = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_domain(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.ModelProto.domain"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional int64 model_version = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + _Internal::set_has_model_version(&has_bits); + model_version_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.GraphProto graph = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + ptr = ctx->ParseMessage(_internal_mutable_graph(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.OperatorSetProto metadata_operators = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr = ctx->ParseMessage(_internal_mutable_metadata_operators(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* ModelProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.ModelProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional int64 ir_version = 1; + if (cached_has_bits & 0x00000008u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(1, this->_internal_ir_version(), target); + } + + // optional string domain = 2; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_domain().data(), static_cast(this->_internal_domain().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.ModelProto.domain"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_domain(), target); + } + + // optional int64 model_version = 3; + if (cached_has_bits & 0x00000010u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(3, this->_internal_model_version(), target); + } + + // optional .mindspore.irpb.GraphProto graph = 4; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 4, _Internal::graph(this), target, stream); + } + + // optional .mindspore.irpb.OperatorSetProto metadata_operators = 5; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 5, _Internal::metadata_operators(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.ModelProto) + return target; +} + +size_t ModelProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.ModelProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000001fu) { + // optional string domain = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_domain()); + } + + // optional .mindspore.irpb.GraphProto graph = 4; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *graph_); + } + + // optional .mindspore.irpb.OperatorSetProto metadata_operators = 5; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *metadata_operators_); + } + + // optional int64 ir_version = 1; + if (cached_has_bits & 0x00000008u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_ir_version()); + } + + // optional int64 model_version = 3; + if (cached_has_bits & 0x00000010u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_model_version()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void ModelProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.ModelProto) + GOOGLE_DCHECK_NE(&from, this); + const ModelProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.ModelProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.ModelProto) + MergeFrom(*source); + } +} + +void ModelProto::MergeFrom(const ModelProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.ModelProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x0000001fu) { + if (cached_has_bits & 0x00000001u) { + _internal_set_domain(from._internal_domain()); + } + if (cached_has_bits & 0x00000002u) { + _internal_mutable_graph()->::mindspore::irpb::GraphProto::MergeFrom(from._internal_graph()); + } + if (cached_has_bits & 0x00000004u) { + _internal_mutable_metadata_operators()->::mindspore::irpb::OperatorSetProto::MergeFrom(from._internal_metadata_operators()); + } + if (cached_has_bits & 0x00000008u) { + ir_version_ = from.ir_version_; + } + if (cached_has_bits & 0x00000010u) { + model_version_ = from.model_version_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void ModelProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.ModelProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ModelProto::CopyFrom(const ModelProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.ModelProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ModelProto::IsInitialized() const { + return true; +} + +void ModelProto::InternalSwap(ModelProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + domain_.Swap(&other->domain_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(ModelProto, model_version_) + + sizeof(ModelProto::model_version_) + - PROTOBUF_FIELD_OFFSET(ModelProto, graph_)>( + reinterpret_cast(&graph_), + reinterpret_cast(&other->graph_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata ModelProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void OperatorProto::InitAsDefaultInstance() { +} +class OperatorProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_config(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_obj_info(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } +}; + +OperatorProto::OperatorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.OperatorProto) +} +OperatorProto::OperatorProto(const OperatorProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + config_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_config()) { + config_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_config(), + GetArena()); + } + obj_info_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_obj_info()) { + obj_info_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_obj_info(), + GetArena()); + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.OperatorProto) +} + +void OperatorProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_OperatorProto_mindspore_5fanf_5fir_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + config_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + obj_info_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +OperatorProto::~OperatorProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.OperatorProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void OperatorProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + config_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + obj_info_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void OperatorProto::ArenaDtor(void* object) { + OperatorProto* _this = reinterpret_cast< OperatorProto* >(object); + (void)_this; +} +void OperatorProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void OperatorProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const OperatorProto& OperatorProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_OperatorProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void OperatorProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.OperatorProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + config_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000004u) { + obj_info_.ClearNonDefaultToEmpty(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* OperatorProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string name = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.OperatorProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional bytes config = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_config(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional bytes obj_info = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_obj_info(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* OperatorProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.OperatorProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.OperatorProto.name"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_name(), target); + } + + // optional bytes config = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->WriteBytesMaybeAliased( + 2, this->_internal_config(), target); + } + + // optional bytes obj_info = 3; + if (cached_has_bits & 0x00000004u) { + target = stream->WriteBytesMaybeAliased( + 3, this->_internal_obj_info(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.OperatorProto) + return target; +} + +size_t OperatorProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.OperatorProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional bytes config = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_config()); + } + + // optional bytes obj_info = 3; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_obj_info()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void OperatorProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.OperatorProto) + GOOGLE_DCHECK_NE(&from, this); + const OperatorProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.OperatorProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.OperatorProto) + MergeFrom(*source); + } +} + +void OperatorProto::MergeFrom(const OperatorProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.OperatorProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_config(from._internal_config()); + } + if (cached_has_bits & 0x00000004u) { + _internal_set_obj_info(from._internal_obj_info()); + } + } +} + +void OperatorProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.OperatorProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void OperatorProto::CopyFrom(const OperatorProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.OperatorProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool OperatorProto::IsInitialized() const { + return true; +} + +void OperatorProto::InternalSwap(OperatorProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + config_.Swap(&other->config_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + obj_info_.Swap(&other->obj_info_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata OperatorProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void OperatorSetProto::InitAsDefaultInstance() { +} +class OperatorSetProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_summary(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +OperatorSetProto::OperatorSetProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + operators_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.OperatorSetProto) +} +OperatorSetProto::OperatorSetProto(const OperatorSetProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + operators_(from.operators_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + summary_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_summary()) { + summary_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_summary(), + GetArena()); + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.OperatorSetProto) +} + +void OperatorSetProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_OperatorSetProto_mindspore_5fanf_5fir_2eproto.base); + summary_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +OperatorSetProto::~OperatorSetProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.OperatorSetProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void OperatorSetProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + summary_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void OperatorSetProto::ArenaDtor(void* object) { + OperatorSetProto* _this = reinterpret_cast< OperatorSetProto* >(object); + (void)_this; +} +void OperatorSetProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void OperatorSetProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const OperatorSetProto& OperatorSetProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_OperatorSetProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void OperatorSetProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.OperatorSetProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + operators_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + summary_.ClearNonDefaultToEmpty(); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* OperatorSetProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .mindspore.irpb.OperatorProto operators = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_operators(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + // optional string summary = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_summary(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.OperatorSetProto.summary"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* OperatorSetProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.OperatorSetProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .mindspore.irpb.OperatorProto operators = 1; + for (unsigned int i = 0, + n = static_cast(this->_internal_operators_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(1, this->_internal_operators(i), target, stream); + } + + cached_has_bits = _has_bits_[0]; + // optional string summary = 2; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_summary().data(), static_cast(this->_internal_summary().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.OperatorSetProto.summary"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_summary(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.OperatorSetProto) + return target; +} + +size_t OperatorSetProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.OperatorSetProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .mindspore.irpb.OperatorProto operators = 1; + total_size += 1UL * this->_internal_operators_size(); + for (const auto& msg : this->operators_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // optional string summary = 2; + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_summary()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void OperatorSetProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.OperatorSetProto) + GOOGLE_DCHECK_NE(&from, this); + const OperatorSetProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.OperatorSetProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.OperatorSetProto) + MergeFrom(*source); + } +} + +void OperatorSetProto::MergeFrom(const OperatorSetProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.OperatorSetProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + operators_.MergeFrom(from.operators_); + if (from._internal_has_summary()) { + _internal_set_summary(from._internal_summary()); + } +} + +void OperatorSetProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.OperatorSetProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void OperatorSetProto::CopyFrom(const OperatorSetProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.OperatorSetProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool OperatorSetProto::IsInitialized() const { + return true; +} + +void OperatorSetProto::InternalSwap(OperatorSetProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + operators_.InternalSwap(&other->operators_); + summary_.Swap(&other->summary_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata OperatorSetProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void GraphProto::InitAsDefaultInstance() { +} +class GraphProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +GraphProto::GraphProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + node_(arena), + parameters_(arena), + outputs_(arena), + const_vals_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.GraphProto) +} +GraphProto::GraphProto(const GraphProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + node_(from.node_), + parameters_(from.parameters_), + outputs_(from.outputs_), + const_vals_(from.const_vals_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.GraphProto) +} + +void GraphProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +GraphProto::~GraphProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.GraphProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void GraphProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void GraphProto::ArenaDtor(void* object) { + GraphProto* _this = reinterpret_cast< GraphProto* >(object); + (void)_this; +} +void GraphProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void GraphProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const GraphProto& GraphProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void GraphProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.GraphProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + node_.Clear(); + parameters_.Clear(); + outputs_.Clear(); + const_vals_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* GraphProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .mindspore.irpb.NodeProto node = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_node(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + // optional string name = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.GraphProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.ParameterProto parameters = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_parameters(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<26>(ptr)); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.OutputProto outputs = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_outputs(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<34>(ptr)); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.NamedValueProto const_vals = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_const_vals(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<42>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* GraphProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.GraphProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .mindspore.irpb.NodeProto node = 1; + for (unsigned int i = 0, + n = static_cast(this->_internal_node_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(1, this->_internal_node(i), target, stream); + } + + cached_has_bits = _has_bits_[0]; + // optional string name = 2; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.GraphProto.name"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_name(), target); + } + + // repeated .mindspore.irpb.ParameterProto parameters = 3; + for (unsigned int i = 0, + n = static_cast(this->_internal_parameters_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(3, this->_internal_parameters(i), target, stream); + } + + // repeated .mindspore.irpb.OutputProto outputs = 4; + for (unsigned int i = 0, + n = static_cast(this->_internal_outputs_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(4, this->_internal_outputs(i), target, stream); + } + + // repeated .mindspore.irpb.NamedValueProto const_vals = 5; + for (unsigned int i = 0, + n = static_cast(this->_internal_const_vals_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(5, this->_internal_const_vals(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.GraphProto) + return target; +} + +size_t GraphProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.GraphProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .mindspore.irpb.NodeProto node = 1; + total_size += 1UL * this->_internal_node_size(); + for (const auto& msg : this->node_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .mindspore.irpb.ParameterProto parameters = 3; + total_size += 1UL * this->_internal_parameters_size(); + for (const auto& msg : this->parameters_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .mindspore.irpb.OutputProto outputs = 4; + total_size += 1UL * this->_internal_outputs_size(); + for (const auto& msg : this->outputs_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .mindspore.irpb.NamedValueProto const_vals = 5; + total_size += 1UL * this->_internal_const_vals_size(); + for (const auto& msg : this->const_vals_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // optional string name = 2; + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void GraphProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.GraphProto) + GOOGLE_DCHECK_NE(&from, this); + const GraphProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.GraphProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.GraphProto) + MergeFrom(*source); + } +} + +void GraphProto::MergeFrom(const GraphProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.GraphProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + node_.MergeFrom(from.node_); + parameters_.MergeFrom(from.parameters_); + outputs_.MergeFrom(from.outputs_); + const_vals_.MergeFrom(from.const_vals_); + if (from._internal_has_name()) { + _internal_set_name(from._internal_name()); + } +} + +void GraphProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.GraphProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void GraphProto::CopyFrom(const GraphProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.GraphProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool GraphProto::IsInitialized() const { + return true; +} + +void GraphProto::InternalSwap(GraphProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + node_.InternalSwap(&other->node_); + parameters_.InternalSwap(&other->parameters_); + outputs_.InternalSwap(&other->outputs_); + const_vals_.InternalSwap(&other->const_vals_); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata GraphProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TensorProto::InitAsDefaultInstance() { +} +class TensorProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_data_type(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_raw_data(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +TensorProto::TensorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + dims_(arena), + float_data_(arena), + int32_data_(arena), + int64_data_(arena), + double_data_(arena), + uint64_data_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.TensorProto) +} +TensorProto::TensorProto(const TensorProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + dims_(from.dims_), + float_data_(from.float_data_), + int32_data_(from.int32_data_), + int64_data_(from.int64_data_), + double_data_(from.double_data_), + uint64_data_(from.uint64_data_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + raw_data_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_raw_data()) { + raw_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_raw_data(), + GetArena()); + } + data_type_ = from.data_type_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.TensorProto) +} + +void TensorProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TensorProto_mindspore_5fanf_5fir_2eproto.base); + raw_data_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + data_type_ = 0; +} + +TensorProto::~TensorProto() { + // @@protoc_insertion_point(destructor:mindspore.irpb.TensorProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TensorProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + raw_data_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void TensorProto::ArenaDtor(void* object) { + TensorProto* _this = reinterpret_cast< TensorProto* >(object); + (void)_this; +} +void TensorProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TensorProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TensorProto& TensorProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorProto_mindspore_5fanf_5fir_2eproto.base); + return *internal_default_instance(); +} + + +void TensorProto::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.TensorProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + dims_.Clear(); + float_data_.Clear(); + int32_data_.Clear(); + int64_data_.Clear(); + double_data_.Clear(); + uint64_data_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + raw_data_.ClearNonDefaultToEmpty(); + } + data_type_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TensorProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated int64 dims = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_dims(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<8>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt64Parser(_internal_mutable_dims(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.DataType data_type = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + if (PROTOBUF_PREDICT_TRUE(::mindspore::irpb::DataType_IsValid(val))) { + _internal_set_data_type(static_cast<::mindspore::irpb::DataType>(val)); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::WriteVarint(2, val, mutable_unknown_fields()); + } + } else goto handle_unusual; + continue; + // repeated float float_data = 3 [packed = true]; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_float_data(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 29) { + _internal_add_float_data(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // repeated int32 int32_data = 4 [packed = true]; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt32Parser(_internal_mutable_int32_data(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 32) { + _internal_add_int32_data(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated int64 int64_data = 5 [packed = true]; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt64Parser(_internal_mutable_int64_data(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 40) { + _internal_add_int64_data(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated double double_data = 6 [packed = true]; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedDoubleParser(_internal_mutable_double_data(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 49) { + _internal_add_double_data(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // repeated uint64 uint64_data = 7 [packed = true]; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedUInt64Parser(_internal_mutable_uint64_data(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 56) { + _internal_add_uint64_data(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional bytes raw_data = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + auto str = _internal_mutable_raw_data(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TensorProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.TensorProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated int64 dims = 1; + for (int i = 0, n = this->_internal_dims_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(1, this->_internal_dims(i), target); + } + + cached_has_bits = _has_bits_[0]; + // optional .mindspore.irpb.DataType data_type = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 2, this->_internal_data_type(), target); + } + + // repeated float float_data = 3 [packed = true]; + if (this->_internal_float_data_size() > 0) { + target = stream->WriteFixedPacked(3, _internal_float_data(), target); + } + + // repeated int32 int32_data = 4 [packed = true]; + { + int byte_size = _int32_data_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteInt32Packed( + 4, _internal_int32_data(), byte_size, target); + } + } + + // repeated int64 int64_data = 5 [packed = true]; + { + int byte_size = _int64_data_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteInt64Packed( + 5, _internal_int64_data(), byte_size, target); + } + } + + // repeated double double_data = 6 [packed = true]; + if (this->_internal_double_data_size() > 0) { + target = stream->WriteFixedPacked(6, _internal_double_data(), target); + } + + // repeated uint64 uint64_data = 7 [packed = true]; + { + int byte_size = _uint64_data_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteUInt64Packed( + 7, _internal_uint64_data(), byte_size, target); + } + } + + // optional bytes raw_data = 8; + if (cached_has_bits & 0x00000001u) { + target = stream->WriteBytesMaybeAliased( + 8, this->_internal_raw_data(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.TensorProto) + return target; +} + +size_t TensorProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.TensorProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated int64 dims = 1; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int64Size(this->dims_); + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_dims_size()); + total_size += data_size; + } + + // repeated float float_data = 3 [packed = true]; + { + unsigned int count = static_cast(this->_internal_float_data_size()); + size_t data_size = 4UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _float_data_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated int32 int32_data = 4 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int32Size(this->int32_data_); + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _int32_data_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated int64 int64_data = 5 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int64Size(this->int64_data_); + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _int64_data_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated double double_data = 6 [packed = true]; + { + unsigned int count = static_cast(this->_internal_double_data_size()); + size_t data_size = 8UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _double_data_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated uint64 uint64_data = 7 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + UInt64Size(this->uint64_data_); + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _uint64_data_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional bytes raw_data = 8; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_raw_data()); + } + + // optional .mindspore.irpb.DataType data_type = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_data_type()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TensorProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.TensorProto) + GOOGLE_DCHECK_NE(&from, this); + const TensorProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.TensorProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.TensorProto) + MergeFrom(*source); + } +} + +void TensorProto::MergeFrom(const TensorProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.TensorProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + dims_.MergeFrom(from.dims_); + float_data_.MergeFrom(from.float_data_); + int32_data_.MergeFrom(from.int32_data_); + int64_data_.MergeFrom(from.int64_data_); + double_data_.MergeFrom(from.double_data_); + uint64_data_.MergeFrom(from.uint64_data_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_raw_data(from._internal_raw_data()); + } + if (cached_has_bits & 0x00000002u) { + data_type_ = from.data_type_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void TensorProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.TensorProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TensorProto::CopyFrom(const TensorProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.TensorProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TensorProto::IsInitialized() const { + return true; +} + +void TensorProto::InternalSwap(TensorProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + dims_.InternalSwap(&other->dims_); + float_data_.InternalSwap(&other->float_data_); + int32_data_.InternalSwap(&other->int32_data_); + int64_data_.InternalSwap(&other->int64_data_); + double_data_.InternalSwap(&other->double_data_); + uint64_data_.InternalSwap(&other->uint64_data_); + raw_data_.Swap(&other->raw_data_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(data_type_, other->data_type_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TensorProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace irpb +} // namespace mindspore +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_NOINLINE ::mindspore::irpb::ValueProto* Arena::CreateMaybeMessage< ::mindspore::irpb::ValueProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::ValueProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::AttributeProto* Arena::CreateMaybeMessage< ::mindspore::irpb::AttributeProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::AttributeProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::NamedValueProto* Arena::CreateMaybeMessage< ::mindspore::irpb::NamedValueProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::NamedValueProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::TensorShapeProto_Dimension* Arena::CreateMaybeMessage< ::mindspore::irpb::TensorShapeProto_Dimension >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::TensorShapeProto_Dimension >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::TensorShapeProto* Arena::CreateMaybeMessage< ::mindspore::irpb::TensorShapeProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::TensorShapeProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::TypeProto_Tensor* Arena::CreateMaybeMessage< ::mindspore::irpb::TypeProto_Tensor >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::TypeProto_Tensor >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::TypeProto_Sequence* Arena::CreateMaybeMessage< ::mindspore::irpb::TypeProto_Sequence >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::TypeProto_Sequence >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::TypeProto* Arena::CreateMaybeMessage< ::mindspore::irpb::TypeProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::TypeProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::ParameterProto* Arena::CreateMaybeMessage< ::mindspore::irpb::ParameterProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::ParameterProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::OutputProto* Arena::CreateMaybeMessage< ::mindspore::irpb::OutputProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::OutputProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::InputProto* Arena::CreateMaybeMessage< ::mindspore::irpb::InputProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::InputProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::NodeProto* Arena::CreateMaybeMessage< ::mindspore::irpb::NodeProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::NodeProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::ModelProto* Arena::CreateMaybeMessage< ::mindspore::irpb::ModelProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::ModelProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::OperatorProto* Arena::CreateMaybeMessage< ::mindspore::irpb::OperatorProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::OperatorProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::OperatorSetProto* Arena::CreateMaybeMessage< ::mindspore::irpb::OperatorSetProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::OperatorSetProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::GraphProto* Arena::CreateMaybeMessage< ::mindspore::irpb::GraphProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::GraphProto >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::TensorProto* Arena::CreateMaybeMessage< ::mindspore::irpb::TensorProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::TensorProto >(arena); +} +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) +#include diff --git a/plugins/mindstudio-insight-plugins/proto/mindspore_anf_ir.pb.h b/plugins/mindstudio-insight-plugins/proto/mindspore_anf_ir.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..64be4dc0defdc1211547d5366d86d213ca483b83 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/mindspore_anf_ir.pb.h @@ -0,0 +1,8235 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mindspore_anf_ir.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_mindspore_5fanf_5fir_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_mindspore_5fanf_5fir_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct TableStruct_mindspore_5fanf_5fir_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[17] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_mindspore_5fanf_5fir_2eproto; +namespace mindspore { +namespace irpb { +class AttributeProto; +class AttributeProtoDefaultTypeInternal; +extern AttributeProtoDefaultTypeInternal _AttributeProto_default_instance_; +class GraphProto; +class GraphProtoDefaultTypeInternal; +extern GraphProtoDefaultTypeInternal _GraphProto_default_instance_; +class InputProto; +class InputProtoDefaultTypeInternal; +extern InputProtoDefaultTypeInternal _InputProto_default_instance_; +class ModelProto; +class ModelProtoDefaultTypeInternal; +extern ModelProtoDefaultTypeInternal _ModelProto_default_instance_; +class NamedValueProto; +class NamedValueProtoDefaultTypeInternal; +extern NamedValueProtoDefaultTypeInternal _NamedValueProto_default_instance_; +class NodeProto; +class NodeProtoDefaultTypeInternal; +extern NodeProtoDefaultTypeInternal _NodeProto_default_instance_; +class OperatorProto; +class OperatorProtoDefaultTypeInternal; +extern OperatorProtoDefaultTypeInternal _OperatorProto_default_instance_; +class OperatorSetProto; +class OperatorSetProtoDefaultTypeInternal; +extern OperatorSetProtoDefaultTypeInternal _OperatorSetProto_default_instance_; +class OutputProto; +class OutputProtoDefaultTypeInternal; +extern OutputProtoDefaultTypeInternal _OutputProto_default_instance_; +class ParameterProto; +class ParameterProtoDefaultTypeInternal; +extern ParameterProtoDefaultTypeInternal _ParameterProto_default_instance_; +class TensorProto; +class TensorProtoDefaultTypeInternal; +extern TensorProtoDefaultTypeInternal _TensorProto_default_instance_; +class TensorShapeProto; +class TensorShapeProtoDefaultTypeInternal; +extern TensorShapeProtoDefaultTypeInternal _TensorShapeProto_default_instance_; +class TensorShapeProto_Dimension; +class TensorShapeProto_DimensionDefaultTypeInternal; +extern TensorShapeProto_DimensionDefaultTypeInternal _TensorShapeProto_Dimension_default_instance_; +class TypeProto; +class TypeProtoDefaultTypeInternal; +extern TypeProtoDefaultTypeInternal _TypeProto_default_instance_; +class TypeProto_Sequence; +class TypeProto_SequenceDefaultTypeInternal; +extern TypeProto_SequenceDefaultTypeInternal _TypeProto_Sequence_default_instance_; +class TypeProto_Tensor; +class TypeProto_TensorDefaultTypeInternal; +extern TypeProto_TensorDefaultTypeInternal _TypeProto_Tensor_default_instance_; +class ValueProto; +class ValueProtoDefaultTypeInternal; +extern ValueProtoDefaultTypeInternal _ValueProto_default_instance_; +} // namespace irpb +} // namespace mindspore +PROTOBUF_NAMESPACE_OPEN +template<> ::mindspore::irpb::AttributeProto* Arena::CreateMaybeMessage<::mindspore::irpb::AttributeProto>(Arena*); +template<> ::mindspore::irpb::GraphProto* Arena::CreateMaybeMessage<::mindspore::irpb::GraphProto>(Arena*); +template<> ::mindspore::irpb::InputProto* Arena::CreateMaybeMessage<::mindspore::irpb::InputProto>(Arena*); +template<> ::mindspore::irpb::ModelProto* Arena::CreateMaybeMessage<::mindspore::irpb::ModelProto>(Arena*); +template<> ::mindspore::irpb::NamedValueProto* Arena::CreateMaybeMessage<::mindspore::irpb::NamedValueProto>(Arena*); +template<> ::mindspore::irpb::NodeProto* Arena::CreateMaybeMessage<::mindspore::irpb::NodeProto>(Arena*); +template<> ::mindspore::irpb::OperatorProto* Arena::CreateMaybeMessage<::mindspore::irpb::OperatorProto>(Arena*); +template<> ::mindspore::irpb::OperatorSetProto* Arena::CreateMaybeMessage<::mindspore::irpb::OperatorSetProto>(Arena*); +template<> ::mindspore::irpb::OutputProto* Arena::CreateMaybeMessage<::mindspore::irpb::OutputProto>(Arena*); +template<> ::mindspore::irpb::ParameterProto* Arena::CreateMaybeMessage<::mindspore::irpb::ParameterProto>(Arena*); +template<> ::mindspore::irpb::TensorProto* Arena::CreateMaybeMessage<::mindspore::irpb::TensorProto>(Arena*); +template<> ::mindspore::irpb::TensorShapeProto* Arena::CreateMaybeMessage<::mindspore::irpb::TensorShapeProto>(Arena*); +template<> ::mindspore::irpb::TensorShapeProto_Dimension* Arena::CreateMaybeMessage<::mindspore::irpb::TensorShapeProto_Dimension>(Arena*); +template<> ::mindspore::irpb::TypeProto* Arena::CreateMaybeMessage<::mindspore::irpb::TypeProto>(Arena*); +template<> ::mindspore::irpb::TypeProto_Sequence* Arena::CreateMaybeMessage<::mindspore::irpb::TypeProto_Sequence>(Arena*); +template<> ::mindspore::irpb::TypeProto_Tensor* Arena::CreateMaybeMessage<::mindspore::irpb::TypeProto_Tensor>(Arena*); +template<> ::mindspore::irpb::ValueProto* Arena::CreateMaybeMessage<::mindspore::irpb::ValueProto>(Arena*); +PROTOBUF_NAMESPACE_CLOSE +namespace mindspore { +namespace irpb { + +enum InputProto_EdgeType : int { + InputProto_EdgeType_DATA_EDGE = 0, + InputProto_EdgeType_CONTROL_EDGE = 1 +}; +bool InputProto_EdgeType_IsValid(int value); +constexpr InputProto_EdgeType InputProto_EdgeType_EdgeType_MIN = InputProto_EdgeType_DATA_EDGE; +constexpr InputProto_EdgeType InputProto_EdgeType_EdgeType_MAX = InputProto_EdgeType_CONTROL_EDGE; +constexpr int InputProto_EdgeType_EdgeType_ARRAYSIZE = InputProto_EdgeType_EdgeType_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* InputProto_EdgeType_descriptor(); +template +inline const std::string& InputProto_EdgeType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function InputProto_EdgeType_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + InputProto_EdgeType_descriptor(), enum_t_value); +} +inline bool InputProto_EdgeType_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, InputProto_EdgeType* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + InputProto_EdgeType_descriptor(), name, value); +} +enum Version : int { + UNKNOWWN_VERSION = 0, + IR_VERSION = 1 +}; +bool Version_IsValid(int value); +constexpr Version Version_MIN = UNKNOWWN_VERSION; +constexpr Version Version_MAX = IR_VERSION; +constexpr int Version_ARRAYSIZE = Version_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* Version_descriptor(); +template +inline const std::string& Version_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function Version_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + Version_descriptor(), enum_t_value); +} +inline bool Version_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, Version* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + Version_descriptor(), name, value); +} +enum DataType : int { + DT_UNDEFINED = 0, + DT_BOOL = 1, + DT_INT8 = 2, + DT_INT16 = 3, + DT_INT32 = 4, + DT_INT64 = 5, + DT_UINT8 = 6, + DT_UINT16 = 7, + DT_UINT32 = 8, + DT_UINT64 = 9, + DT_FLOAT16 = 10, + DT_FLOAT32 = 11, + DT_FLOAT64 = 12, + DT_STRING = 13, + DT_TENSOR = 14, + DT_GRAPH = 15, + DT_BOOLS = 16, + DT_INTS8 = 17, + DT_INTS16 = 18, + DT_INTS32 = 19, + DT_INTS64 = 20, + DT_UINTS8 = 21, + DT_UINTS16 = 22, + DT_UINTS32 = 23, + DT_UINTS64 = 24, + DT_FLOATS16 = 25, + DT_FLOATS32 = 26, + DT_FLOATS64 = 27, + DT_STRINGS = 28, + DT_TENSORS = 29, + DT_GRAPHS = 30, + DT_TUPLE = 31, + DT_LIST = 32, + DT_DICT = 33, + DT_NONE = 34, + DT_SYM_INST = 35, + DT_BASE_INT = 36, + DT_BASE_UINT = 37, + DT_BASE_FLOAT = 38, + DT_TYPE = 39, + DT_ANY = 40, + DT_REFKEY = 41, + DT_REF = 42, + DT_COMPLEX64 = 43, + DT_COMPLEX128 = 44, + DT_BASE_COMPLEX = 45, + DT_BFLOAT16 = 46, + DT_BFLOATS16 = 47, + DT_INT4 = 48, + DT_SLICE = 49 +}; +bool DataType_IsValid(int value); +constexpr DataType DataType_MIN = DT_UNDEFINED; +constexpr DataType DataType_MAX = DT_SLICE; +constexpr int DataType_ARRAYSIZE = DataType_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* DataType_descriptor(); +template +inline const std::string& DataType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function DataType_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + DataType_descriptor(), enum_t_value); +} +inline bool DataType_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, DataType* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + DataType_descriptor(), name, value); +} +// =================================================================== + +class ValueProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.ValueProto) */ { + public: + inline ValueProto() : ValueProto(nullptr) {} + virtual ~ValueProto(); + + ValueProto(const ValueProto& from); + ValueProto(ValueProto&& from) noexcept + : ValueProto() { + *this = ::std::move(from); + } + + inline ValueProto& operator=(const ValueProto& from) { + CopyFrom(from); + return *this; + } + inline ValueProto& operator=(ValueProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ValueProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ValueProto* internal_default_instance() { + return reinterpret_cast( + &_ValueProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(ValueProto& a, ValueProto& b) { + a.Swap(&b); + } + inline void Swap(ValueProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ValueProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ValueProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + ValueProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ValueProto& from); + void MergeFrom(const ValueProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ValueProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.ValueProto"; + } + protected: + explicit ValueProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kBoolValsFieldNumber = 10, + kIntValsFieldNumber = 11, + kUintValsFieldNumber = 12, + kFloatValsFieldNumber = 13, + kDoubleValsFieldNumber = 14, + kStrValsFieldNumber = 15, + kTensorValsFieldNumber = 16, + kGraphsFieldNumber = 17, + kValuesFieldNumber = 18, + kDictValFieldNumber = 19, + kStrValFieldNumber = 7, + kTensorValFieldNumber = 8, + kGraphFieldNumber = 9, + kTypeValFieldNumber = 20, + kDtypeFieldNumber = 1, + kBoolValFieldNumber = 2, + kIntValFieldNumber = 3, + kUintValFieldNumber = 4, + kDoubleValFieldNumber = 6, + kFloatValFieldNumber = 5, + }; + // repeated bool bool_vals = 10; + int bool_vals_size() const; + private: + int _internal_bool_vals_size() const; + public: + void clear_bool_vals(); + private: + bool _internal_bool_vals(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >& + _internal_bool_vals() const; + void _internal_add_bool_vals(bool value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >* + _internal_mutable_bool_vals(); + public: + bool bool_vals(int index) const; + void set_bool_vals(int index, bool value); + void add_bool_vals(bool value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >& + bool_vals() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >* + mutable_bool_vals(); + + // repeated int64 int_vals = 11; + int int_vals_size() const; + private: + int _internal_int_vals_size() const; + public: + void clear_int_vals(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_int_vals(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_int_vals() const; + void _internal_add_int_vals(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_int_vals(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 int_vals(int index) const; + void set_int_vals(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_int_vals(::PROTOBUF_NAMESPACE_ID::int64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + int_vals() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_int_vals(); + + // repeated uint64 uint_vals = 12; + int uint_vals_size() const; + private: + int _internal_uint_vals_size() const; + public: + void clear_uint_vals(); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_uint_vals(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& + _internal_uint_vals() const; + void _internal_add_uint_vals(::PROTOBUF_NAMESPACE_ID::uint64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* + _internal_mutable_uint_vals(); + public: + ::PROTOBUF_NAMESPACE_ID::uint64 uint_vals(int index) const; + void set_uint_vals(int index, ::PROTOBUF_NAMESPACE_ID::uint64 value); + void add_uint_vals(::PROTOBUF_NAMESPACE_ID::uint64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& + uint_vals() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* + mutable_uint_vals(); + + // repeated float float_vals = 13; + int float_vals_size() const; + private: + int _internal_float_vals_size() const; + public: + void clear_float_vals(); + private: + float _internal_float_vals(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_float_vals() const; + void _internal_add_float_vals(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_float_vals(); + public: + float float_vals(int index) const; + void set_float_vals(int index, float value); + void add_float_vals(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + float_vals() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_float_vals(); + + // repeated double double_vals = 14; + int double_vals_size() const; + private: + int _internal_double_vals_size() const; + public: + void clear_double_vals(); + private: + double _internal_double_vals(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + _internal_double_vals() const; + void _internal_add_double_vals(double value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + _internal_mutable_double_vals(); + public: + double double_vals(int index) const; + void set_double_vals(int index, double value); + void add_double_vals(double value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + double_vals() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + mutable_double_vals(); + + // repeated string str_vals = 15; + int str_vals_size() const; + private: + int _internal_str_vals_size() const; + public: + void clear_str_vals(); + const std::string& str_vals(int index) const; + std::string* mutable_str_vals(int index); + void set_str_vals(int index, const std::string& value); + void set_str_vals(int index, std::string&& value); + void set_str_vals(int index, const char* value); + void set_str_vals(int index, const char* value, size_t size); + std::string* add_str_vals(); + void add_str_vals(const std::string& value); + void add_str_vals(std::string&& value); + void add_str_vals(const char* value); + void add_str_vals(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& str_vals() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_str_vals(); + private: + const std::string& _internal_str_vals(int index) const; + std::string* _internal_add_str_vals(); + public: + + // repeated .mindspore.irpb.TensorProto tensor_vals = 16; + int tensor_vals_size() const; + private: + int _internal_tensor_vals_size() const; + public: + void clear_tensor_vals(); + ::mindspore::irpb::TensorProto* mutable_tensor_vals(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TensorProto >* + mutable_tensor_vals(); + private: + const ::mindspore::irpb::TensorProto& _internal_tensor_vals(int index) const; + ::mindspore::irpb::TensorProto* _internal_add_tensor_vals(); + public: + const ::mindspore::irpb::TensorProto& tensor_vals(int index) const; + ::mindspore::irpb::TensorProto* add_tensor_vals(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TensorProto >& + tensor_vals() const; + + // repeated .mindspore.irpb.GraphProto graphs = 17; + int graphs_size() const; + private: + int _internal_graphs_size() const; + public: + void clear_graphs(); + ::mindspore::irpb::GraphProto* mutable_graphs(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::GraphProto >* + mutable_graphs(); + private: + const ::mindspore::irpb::GraphProto& _internal_graphs(int index) const; + ::mindspore::irpb::GraphProto* _internal_add_graphs(); + public: + const ::mindspore::irpb::GraphProto& graphs(int index) const; + ::mindspore::irpb::GraphProto* add_graphs(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::GraphProto >& + graphs() const; + + // repeated .mindspore.irpb.ValueProto values = 18; + int values_size() const; + private: + int _internal_values_size() const; + public: + void clear_values(); + ::mindspore::irpb::ValueProto* mutable_values(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::ValueProto >* + mutable_values(); + private: + const ::mindspore::irpb::ValueProto& _internal_values(int index) const; + ::mindspore::irpb::ValueProto* _internal_add_values(); + public: + const ::mindspore::irpb::ValueProto& values(int index) const; + ::mindspore::irpb::ValueProto* add_values(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::ValueProto >& + values() const; + + // repeated .mindspore.irpb.NamedValueProto dict_val = 19; + int dict_val_size() const; + private: + int _internal_dict_val_size() const; + public: + void clear_dict_val(); + ::mindspore::irpb::NamedValueProto* mutable_dict_val(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NamedValueProto >* + mutable_dict_val(); + private: + const ::mindspore::irpb::NamedValueProto& _internal_dict_val(int index) const; + ::mindspore::irpb::NamedValueProto* _internal_add_dict_val(); + public: + const ::mindspore::irpb::NamedValueProto& dict_val(int index) const; + ::mindspore::irpb::NamedValueProto* add_dict_val(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NamedValueProto >& + dict_val() const; + + // optional string str_val = 7; + bool has_str_val() const; + private: + bool _internal_has_str_val() const; + public: + void clear_str_val(); + const std::string& str_val() const; + void set_str_val(const std::string& value); + void set_str_val(std::string&& value); + void set_str_val(const char* value); + void set_str_val(const char* value, size_t size); + std::string* mutable_str_val(); + std::string* release_str_val(); + void set_allocated_str_val(std::string* str_val); + private: + const std::string& _internal_str_val() const; + void _internal_set_str_val(const std::string& value); + std::string* _internal_mutable_str_val(); + public: + + // optional .mindspore.irpb.TensorProto tensor_val = 8; + bool has_tensor_val() const; + private: + bool _internal_has_tensor_val() const; + public: + void clear_tensor_val(); + const ::mindspore::irpb::TensorProto& tensor_val() const; + ::mindspore::irpb::TensorProto* release_tensor_val(); + ::mindspore::irpb::TensorProto* mutable_tensor_val(); + void set_allocated_tensor_val(::mindspore::irpb::TensorProto* tensor_val); + private: + const ::mindspore::irpb::TensorProto& _internal_tensor_val() const; + ::mindspore::irpb::TensorProto* _internal_mutable_tensor_val(); + public: + void unsafe_arena_set_allocated_tensor_val( + ::mindspore::irpb::TensorProto* tensor_val); + ::mindspore::irpb::TensorProto* unsafe_arena_release_tensor_val(); + + // optional .mindspore.irpb.GraphProto graph = 9; + bool has_graph() const; + private: + bool _internal_has_graph() const; + public: + void clear_graph(); + const ::mindspore::irpb::GraphProto& graph() const; + ::mindspore::irpb::GraphProto* release_graph(); + ::mindspore::irpb::GraphProto* mutable_graph(); + void set_allocated_graph(::mindspore::irpb::GraphProto* graph); + private: + const ::mindspore::irpb::GraphProto& _internal_graph() const; + ::mindspore::irpb::GraphProto* _internal_mutable_graph(); + public: + void unsafe_arena_set_allocated_graph( + ::mindspore::irpb::GraphProto* graph); + ::mindspore::irpb::GraphProto* unsafe_arena_release_graph(); + + // optional .mindspore.irpb.TypeProto type_val = 20; + bool has_type_val() const; + private: + bool _internal_has_type_val() const; + public: + void clear_type_val(); + const ::mindspore::irpb::TypeProto& type_val() const; + ::mindspore::irpb::TypeProto* release_type_val(); + ::mindspore::irpb::TypeProto* mutable_type_val(); + void set_allocated_type_val(::mindspore::irpb::TypeProto* type_val); + private: + const ::mindspore::irpb::TypeProto& _internal_type_val() const; + ::mindspore::irpb::TypeProto* _internal_mutable_type_val(); + public: + void unsafe_arena_set_allocated_type_val( + ::mindspore::irpb::TypeProto* type_val); + ::mindspore::irpb::TypeProto* unsafe_arena_release_type_val(); + + // optional .mindspore.irpb.DataType dtype = 1; + bool has_dtype() const; + private: + bool _internal_has_dtype() const; + public: + void clear_dtype(); + ::mindspore::irpb::DataType dtype() const; + void set_dtype(::mindspore::irpb::DataType value); + private: + ::mindspore::irpb::DataType _internal_dtype() const; + void _internal_set_dtype(::mindspore::irpb::DataType value); + public: + + // optional bool bool_val = 2; + bool has_bool_val() const; + private: + bool _internal_has_bool_val() const; + public: + void clear_bool_val(); + bool bool_val() const; + void set_bool_val(bool value); + private: + bool _internal_bool_val() const; + void _internal_set_bool_val(bool value); + public: + + // optional int64 int_val = 3; + bool has_int_val() const; + private: + bool _internal_has_int_val() const; + public: + void clear_int_val(); + ::PROTOBUF_NAMESPACE_ID::int64 int_val() const; + void set_int_val(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_int_val() const; + void _internal_set_int_val(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional uint64 uint_val = 4; + bool has_uint_val() const; + private: + bool _internal_has_uint_val() const; + public: + void clear_uint_val(); + ::PROTOBUF_NAMESPACE_ID::uint64 uint_val() const; + void set_uint_val(::PROTOBUF_NAMESPACE_ID::uint64 value); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_uint_val() const; + void _internal_set_uint_val(::PROTOBUF_NAMESPACE_ID::uint64 value); + public: + + // optional double double_val = 6; + bool has_double_val() const; + private: + bool _internal_has_double_val() const; + public: + void clear_double_val(); + double double_val() const; + void set_double_val(double value); + private: + double _internal_double_val() const; + void _internal_set_double_val(double value); + public: + + // optional float float_val = 5; + bool has_float_val() const; + private: + bool _internal_has_float_val() const; + public: + void clear_float_val(); + float float_val() const; + void set_float_val(float value); + private: + float _internal_float_val() const; + void _internal_set_float_val(float value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.ValueProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool > bool_vals_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > int_vals_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 > uint_vals_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > float_vals_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double > double_vals_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField str_vals_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TensorProto > tensor_vals_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::GraphProto > graphs_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::ValueProto > values_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NamedValueProto > dict_val_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr str_val_; + ::mindspore::irpb::TensorProto* tensor_val_; + ::mindspore::irpb::GraphProto* graph_; + ::mindspore::irpb::TypeProto* type_val_; + int dtype_; + bool bool_val_; + ::PROTOBUF_NAMESPACE_ID::int64 int_val_; + ::PROTOBUF_NAMESPACE_ID::uint64 uint_val_; + double double_val_; + float float_val_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class AttributeProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.AttributeProto) */ { + public: + inline AttributeProto() : AttributeProto(nullptr) {} + virtual ~AttributeProto(); + + AttributeProto(const AttributeProto& from); + AttributeProto(AttributeProto&& from) noexcept + : AttributeProto() { + *this = ::std::move(from); + } + + inline AttributeProto& operator=(const AttributeProto& from) { + CopyFrom(from); + return *this; + } + inline AttributeProto& operator=(AttributeProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const AttributeProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const AttributeProto* internal_default_instance() { + return reinterpret_cast( + &_AttributeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(AttributeProto& a, AttributeProto& b) { + a.Swap(&b); + } + inline void Swap(AttributeProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(AttributeProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline AttributeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + AttributeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const AttributeProto& from); + void MergeFrom(const AttributeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(AttributeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.AttributeProto"; + } + protected: + explicit AttributeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kValueFieldNumber = 2, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional .mindspore.irpb.ValueProto value = 2; + bool has_value() const; + private: + bool _internal_has_value() const; + public: + void clear_value(); + const ::mindspore::irpb::ValueProto& value() const; + ::mindspore::irpb::ValueProto* release_value(); + ::mindspore::irpb::ValueProto* mutable_value(); + void set_allocated_value(::mindspore::irpb::ValueProto* value); + private: + const ::mindspore::irpb::ValueProto& _internal_value() const; + ::mindspore::irpb::ValueProto* _internal_mutable_value(); + public: + void unsafe_arena_set_allocated_value( + ::mindspore::irpb::ValueProto* value); + ::mindspore::irpb::ValueProto* unsafe_arena_release_value(); + + // @@protoc_insertion_point(class_scope:mindspore.irpb.AttributeProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::mindspore::irpb::ValueProto* value_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class NamedValueProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.NamedValueProto) */ { + public: + inline NamedValueProto() : NamedValueProto(nullptr) {} + virtual ~NamedValueProto(); + + NamedValueProto(const NamedValueProto& from); + NamedValueProto(NamedValueProto&& from) noexcept + : NamedValueProto() { + *this = ::std::move(from); + } + + inline NamedValueProto& operator=(const NamedValueProto& from) { + CopyFrom(from); + return *this; + } + inline NamedValueProto& operator=(NamedValueProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const NamedValueProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const NamedValueProto* internal_default_instance() { + return reinterpret_cast( + &_NamedValueProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(NamedValueProto& a, NamedValueProto& b) { + a.Swap(&b); + } + inline void Swap(NamedValueProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(NamedValueProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline NamedValueProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + NamedValueProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const NamedValueProto& from); + void MergeFrom(const NamedValueProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(NamedValueProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.NamedValueProto"; + } + protected: + explicit NamedValueProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kKeyFieldNumber = 1, + kValueFieldNumber = 2, + }; + // optional string key = 1; + bool has_key() const; + private: + bool _internal_has_key() const; + public: + void clear_key(); + const std::string& key() const; + void set_key(const std::string& value); + void set_key(std::string&& value); + void set_key(const char* value); + void set_key(const char* value, size_t size); + std::string* mutable_key(); + std::string* release_key(); + void set_allocated_key(std::string* key); + private: + const std::string& _internal_key() const; + void _internal_set_key(const std::string& value); + std::string* _internal_mutable_key(); + public: + + // optional .mindspore.irpb.ValueProto value = 2; + bool has_value() const; + private: + bool _internal_has_value() const; + public: + void clear_value(); + const ::mindspore::irpb::ValueProto& value() const; + ::mindspore::irpb::ValueProto* release_value(); + ::mindspore::irpb::ValueProto* mutable_value(); + void set_allocated_value(::mindspore::irpb::ValueProto* value); + private: + const ::mindspore::irpb::ValueProto& _internal_value() const; + ::mindspore::irpb::ValueProto* _internal_mutable_value(); + public: + void unsafe_arena_set_allocated_value( + ::mindspore::irpb::ValueProto* value); + ::mindspore::irpb::ValueProto* unsafe_arena_release_value(); + + // @@protoc_insertion_point(class_scope:mindspore.irpb.NamedValueProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr key_; + ::mindspore::irpb::ValueProto* value_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorShapeProto_Dimension PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.TensorShapeProto.Dimension) */ { + public: + inline TensorShapeProto_Dimension() : TensorShapeProto_Dimension(nullptr) {} + virtual ~TensorShapeProto_Dimension(); + + TensorShapeProto_Dimension(const TensorShapeProto_Dimension& from); + TensorShapeProto_Dimension(TensorShapeProto_Dimension&& from) noexcept + : TensorShapeProto_Dimension() { + *this = ::std::move(from); + } + + inline TensorShapeProto_Dimension& operator=(const TensorShapeProto_Dimension& from) { + CopyFrom(from); + return *this; + } + inline TensorShapeProto_Dimension& operator=(TensorShapeProto_Dimension&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorShapeProto_Dimension& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorShapeProto_Dimension* internal_default_instance() { + return reinterpret_cast( + &_TensorShapeProto_Dimension_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(TensorShapeProto_Dimension& a, TensorShapeProto_Dimension& b) { + a.Swap(&b); + } + inline void Swap(TensorShapeProto_Dimension* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TensorShapeProto_Dimension* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorShapeProto_Dimension* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorShapeProto_Dimension* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorShapeProto_Dimension& from); + void MergeFrom(const TensorShapeProto_Dimension& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorShapeProto_Dimension* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.TensorShapeProto.Dimension"; + } + protected: + explicit TensorShapeProto_Dimension(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 2, + kSizeFieldNumber = 1, + }; + // optional string name = 2; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional int64 size = 1; + bool has_size() const; + private: + bool _internal_has_size() const; + public: + void clear_size(); + ::PROTOBUF_NAMESPACE_ID::int64 size() const; + void set_size(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_size() const; + void _internal_set_size(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.TensorShapeProto.Dimension) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::int64 size_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorShapeProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.TensorShapeProto) */ { + public: + inline TensorShapeProto() : TensorShapeProto(nullptr) {} + virtual ~TensorShapeProto(); + + TensorShapeProto(const TensorShapeProto& from); + TensorShapeProto(TensorShapeProto&& from) noexcept + : TensorShapeProto() { + *this = ::std::move(from); + } + + inline TensorShapeProto& operator=(const TensorShapeProto& from) { + CopyFrom(from); + return *this; + } + inline TensorShapeProto& operator=(TensorShapeProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorShapeProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorShapeProto* internal_default_instance() { + return reinterpret_cast( + &_TensorShapeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 4; + + friend void swap(TensorShapeProto& a, TensorShapeProto& b) { + a.Swap(&b); + } + inline void Swap(TensorShapeProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TensorShapeProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorShapeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorShapeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorShapeProto& from); + void MergeFrom(const TensorShapeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorShapeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.TensorShapeProto"; + } + protected: + explicit TensorShapeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef TensorShapeProto_Dimension Dimension; + + // accessors ------------------------------------------------------- + + enum : int { + kDimFieldNumber = 1, + }; + // repeated .mindspore.irpb.TensorShapeProto.Dimension dim = 1; + int dim_size() const; + private: + int _internal_dim_size() const; + public: + void clear_dim(); + ::mindspore::irpb::TensorShapeProto_Dimension* mutable_dim(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TensorShapeProto_Dimension >* + mutable_dim(); + private: + const ::mindspore::irpb::TensorShapeProto_Dimension& _internal_dim(int index) const; + ::mindspore::irpb::TensorShapeProto_Dimension* _internal_add_dim(); + public: + const ::mindspore::irpb::TensorShapeProto_Dimension& dim(int index) const; + ::mindspore::irpb::TensorShapeProto_Dimension* add_dim(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TensorShapeProto_Dimension >& + dim() const; + + // @@protoc_insertion_point(class_scope:mindspore.irpb.TensorShapeProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TensorShapeProto_Dimension > dim_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto_Tensor PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.TypeProto.Tensor) */ { + public: + inline TypeProto_Tensor() : TypeProto_Tensor(nullptr) {} + virtual ~TypeProto_Tensor(); + + TypeProto_Tensor(const TypeProto_Tensor& from); + TypeProto_Tensor(TypeProto_Tensor&& from) noexcept + : TypeProto_Tensor() { + *this = ::std::move(from); + } + + inline TypeProto_Tensor& operator=(const TypeProto_Tensor& from) { + CopyFrom(from); + return *this; + } + inline TypeProto_Tensor& operator=(TypeProto_Tensor&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto_Tensor& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto_Tensor* internal_default_instance() { + return reinterpret_cast( + &_TypeProto_Tensor_default_instance_); + } + static constexpr int kIndexInFileMessages = + 5; + + friend void swap(TypeProto_Tensor& a, TypeProto_Tensor& b) { + a.Swap(&b); + } + inline void Swap(TypeProto_Tensor* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TypeProto_Tensor* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto_Tensor* New() const final { + return CreateMaybeMessage(nullptr); + } + + TypeProto_Tensor* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto_Tensor& from); + void MergeFrom(const TypeProto_Tensor& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto_Tensor* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.TypeProto.Tensor"; + } + protected: + explicit TypeProto_Tensor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kShapeFieldNumber = 2, + kElemTypeFieldNumber = 1, + }; + // optional .mindspore.irpb.TensorShapeProto shape = 2; + bool has_shape() const; + private: + bool _internal_has_shape() const; + public: + void clear_shape(); + const ::mindspore::irpb::TensorShapeProto& shape() const; + ::mindspore::irpb::TensorShapeProto* release_shape(); + ::mindspore::irpb::TensorShapeProto* mutable_shape(); + void set_allocated_shape(::mindspore::irpb::TensorShapeProto* shape); + private: + const ::mindspore::irpb::TensorShapeProto& _internal_shape() const; + ::mindspore::irpb::TensorShapeProto* _internal_mutable_shape(); + public: + void unsafe_arena_set_allocated_shape( + ::mindspore::irpb::TensorShapeProto* shape); + ::mindspore::irpb::TensorShapeProto* unsafe_arena_release_shape(); + + // optional .mindspore.irpb.DataType elem_type = 1; + bool has_elem_type() const; + private: + bool _internal_has_elem_type() const; + public: + void clear_elem_type(); + ::mindspore::irpb::DataType elem_type() const; + void set_elem_type(::mindspore::irpb::DataType value); + private: + ::mindspore::irpb::DataType _internal_elem_type() const; + void _internal_set_elem_type(::mindspore::irpb::DataType value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.TypeProto.Tensor) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::mindspore::irpb::TensorShapeProto* shape_; + int elem_type_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto_Sequence PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.TypeProto.Sequence) */ { + public: + inline TypeProto_Sequence() : TypeProto_Sequence(nullptr) {} + virtual ~TypeProto_Sequence(); + + TypeProto_Sequence(const TypeProto_Sequence& from); + TypeProto_Sequence(TypeProto_Sequence&& from) noexcept + : TypeProto_Sequence() { + *this = ::std::move(from); + } + + inline TypeProto_Sequence& operator=(const TypeProto_Sequence& from) { + CopyFrom(from); + return *this; + } + inline TypeProto_Sequence& operator=(TypeProto_Sequence&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto_Sequence& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto_Sequence* internal_default_instance() { + return reinterpret_cast( + &_TypeProto_Sequence_default_instance_); + } + static constexpr int kIndexInFileMessages = + 6; + + friend void swap(TypeProto_Sequence& a, TypeProto_Sequence& b) { + a.Swap(&b); + } + inline void Swap(TypeProto_Sequence* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TypeProto_Sequence* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto_Sequence* New() const final { + return CreateMaybeMessage(nullptr); + } + + TypeProto_Sequence* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto_Sequence& from); + void MergeFrom(const TypeProto_Sequence& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto_Sequence* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.TypeProto.Sequence"; + } + protected: + explicit TypeProto_Sequence(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kElemTypesFieldNumber = 1, + }; + // repeated .mindspore.irpb.TypeProto elem_types = 1; + int elem_types_size() const; + private: + int _internal_elem_types_size() const; + public: + void clear_elem_types(); + ::mindspore::irpb::TypeProto* mutable_elem_types(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TypeProto >* + mutable_elem_types(); + private: + const ::mindspore::irpb::TypeProto& _internal_elem_types(int index) const; + ::mindspore::irpb::TypeProto* _internal_add_elem_types(); + public: + const ::mindspore::irpb::TypeProto& elem_types(int index) const; + ::mindspore::irpb::TypeProto* add_elem_types(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TypeProto >& + elem_types() const; + + // @@protoc_insertion_point(class_scope:mindspore.irpb.TypeProto.Sequence) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TypeProto > elem_types_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.TypeProto) */ { + public: + inline TypeProto() : TypeProto(nullptr) {} + virtual ~TypeProto(); + + TypeProto(const TypeProto& from); + TypeProto(TypeProto&& from) noexcept + : TypeProto() { + *this = ::std::move(from); + } + + inline TypeProto& operator=(const TypeProto& from) { + CopyFrom(from); + return *this; + } + inline TypeProto& operator=(TypeProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto& default_instance(); + + enum ValueCase { + kTensorType = 2, + kSequenceType = 3, + VALUE_NOT_SET = 0, + }; + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto* internal_default_instance() { + return reinterpret_cast( + &_TypeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 7; + + friend void swap(TypeProto& a, TypeProto& b) { + a.Swap(&b); + } + inline void Swap(TypeProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TypeProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + TypeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto& from); + void MergeFrom(const TypeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.TypeProto"; + } + protected: + explicit TypeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef TypeProto_Tensor Tensor; + typedef TypeProto_Sequence Sequence; + + // accessors ------------------------------------------------------- + + enum : int { + kDataTypeFieldNumber = 1, + kTensorTypeFieldNumber = 2, + kSequenceTypeFieldNumber = 3, + }; + // optional .mindspore.irpb.DataType data_type = 1; + bool has_data_type() const; + private: + bool _internal_has_data_type() const; + public: + void clear_data_type(); + ::mindspore::irpb::DataType data_type() const; + void set_data_type(::mindspore::irpb::DataType value); + private: + ::mindspore::irpb::DataType _internal_data_type() const; + void _internal_set_data_type(::mindspore::irpb::DataType value); + public: + + // .mindspore.irpb.TypeProto.Tensor tensor_type = 2; + bool has_tensor_type() const; + private: + bool _internal_has_tensor_type() const; + public: + void clear_tensor_type(); + const ::mindspore::irpb::TypeProto_Tensor& tensor_type() const; + ::mindspore::irpb::TypeProto_Tensor* release_tensor_type(); + ::mindspore::irpb::TypeProto_Tensor* mutable_tensor_type(); + void set_allocated_tensor_type(::mindspore::irpb::TypeProto_Tensor* tensor_type); + private: + const ::mindspore::irpb::TypeProto_Tensor& _internal_tensor_type() const; + ::mindspore::irpb::TypeProto_Tensor* _internal_mutable_tensor_type(); + public: + void unsafe_arena_set_allocated_tensor_type( + ::mindspore::irpb::TypeProto_Tensor* tensor_type); + ::mindspore::irpb::TypeProto_Tensor* unsafe_arena_release_tensor_type(); + + // .mindspore.irpb.TypeProto.Sequence sequence_type = 3; + bool has_sequence_type() const; + private: + bool _internal_has_sequence_type() const; + public: + void clear_sequence_type(); + const ::mindspore::irpb::TypeProto_Sequence& sequence_type() const; + ::mindspore::irpb::TypeProto_Sequence* release_sequence_type(); + ::mindspore::irpb::TypeProto_Sequence* mutable_sequence_type(); + void set_allocated_sequence_type(::mindspore::irpb::TypeProto_Sequence* sequence_type); + private: + const ::mindspore::irpb::TypeProto_Sequence& _internal_sequence_type() const; + ::mindspore::irpb::TypeProto_Sequence* _internal_mutable_sequence_type(); + public: + void unsafe_arena_set_allocated_sequence_type( + ::mindspore::irpb::TypeProto_Sequence* sequence_type); + ::mindspore::irpb::TypeProto_Sequence* unsafe_arena_release_sequence_type(); + + void clear_value(); + ValueCase value_case() const; + // @@protoc_insertion_point(class_scope:mindspore.irpb.TypeProto) + private: + class _Internal; + void set_has_tensor_type(); + void set_has_sequence_type(); + + inline bool has_value() const; + inline void clear_has_value(); + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + int data_type_; + union ValueUnion { + ValueUnion() {} + ::mindspore::irpb::TypeProto_Tensor* tensor_type_; + ::mindspore::irpb::TypeProto_Sequence* sequence_type_; + } value_; + ::PROTOBUF_NAMESPACE_ID::uint32 _oneof_case_[1]; + + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class ParameterProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.ParameterProto) */ { + public: + inline ParameterProto() : ParameterProto(nullptr) {} + virtual ~ParameterProto(); + + ParameterProto(const ParameterProto& from); + ParameterProto(ParameterProto&& from) noexcept + : ParameterProto() { + *this = ::std::move(from); + } + + inline ParameterProto& operator=(const ParameterProto& from) { + CopyFrom(from); + return *this; + } + inline ParameterProto& operator=(ParameterProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ParameterProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ParameterProto* internal_default_instance() { + return reinterpret_cast( + &_ParameterProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 8; + + friend void swap(ParameterProto& a, ParameterProto& b) { + a.Swap(&b); + } + inline void Swap(ParameterProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ParameterProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ParameterProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + ParameterProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ParameterProto& from); + void MergeFrom(const ParameterProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ParameterProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.ParameterProto"; + } + protected: + explicit ParameterProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kTypeFieldNumber = 2, + kDefaultValFieldNumber = 3, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional .mindspore.irpb.TypeProto type = 2; + bool has_type() const; + private: + bool _internal_has_type() const; + public: + void clear_type(); + const ::mindspore::irpb::TypeProto& type() const; + ::mindspore::irpb::TypeProto* release_type(); + ::mindspore::irpb::TypeProto* mutable_type(); + void set_allocated_type(::mindspore::irpb::TypeProto* type); + private: + const ::mindspore::irpb::TypeProto& _internal_type() const; + ::mindspore::irpb::TypeProto* _internal_mutable_type(); + public: + void unsafe_arena_set_allocated_type( + ::mindspore::irpb::TypeProto* type); + ::mindspore::irpb::TypeProto* unsafe_arena_release_type(); + + // optional .mindspore.irpb.ValueProto default_val = 3; + bool has_default_val() const; + private: + bool _internal_has_default_val() const; + public: + void clear_default_val(); + const ::mindspore::irpb::ValueProto& default_val() const; + ::mindspore::irpb::ValueProto* release_default_val(); + ::mindspore::irpb::ValueProto* mutable_default_val(); + void set_allocated_default_val(::mindspore::irpb::ValueProto* default_val); + private: + const ::mindspore::irpb::ValueProto& _internal_default_val() const; + ::mindspore::irpb::ValueProto* _internal_mutable_default_val(); + public: + void unsafe_arena_set_allocated_default_val( + ::mindspore::irpb::ValueProto* default_val); + ::mindspore::irpb::ValueProto* unsafe_arena_release_default_val(); + + // @@protoc_insertion_point(class_scope:mindspore.irpb.ParameterProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::mindspore::irpb::TypeProto* type_; + ::mindspore::irpb::ValueProto* default_val_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class OutputProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.OutputProto) */ { + public: + inline OutputProto() : OutputProto(nullptr) {} + virtual ~OutputProto(); + + OutputProto(const OutputProto& from); + OutputProto(OutputProto&& from) noexcept + : OutputProto() { + *this = ::std::move(from); + } + + inline OutputProto& operator=(const OutputProto& from) { + CopyFrom(from); + return *this; + } + inline OutputProto& operator=(OutputProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const OutputProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const OutputProto* internal_default_instance() { + return reinterpret_cast( + &_OutputProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 9; + + friend void swap(OutputProto& a, OutputProto& b) { + a.Swap(&b); + } + inline void Swap(OutputProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(OutputProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline OutputProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + OutputProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const OutputProto& from); + void MergeFrom(const OutputProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(OutputProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.OutputProto"; + } + protected: + explicit OutputProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kTypeFieldNumber = 2, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional .mindspore.irpb.TypeProto type = 2; + bool has_type() const; + private: + bool _internal_has_type() const; + public: + void clear_type(); + const ::mindspore::irpb::TypeProto& type() const; + ::mindspore::irpb::TypeProto* release_type(); + ::mindspore::irpb::TypeProto* mutable_type(); + void set_allocated_type(::mindspore::irpb::TypeProto* type); + private: + const ::mindspore::irpb::TypeProto& _internal_type() const; + ::mindspore::irpb::TypeProto* _internal_mutable_type(); + public: + void unsafe_arena_set_allocated_type( + ::mindspore::irpb::TypeProto* type); + ::mindspore::irpb::TypeProto* unsafe_arena_release_type(); + + // @@protoc_insertion_point(class_scope:mindspore.irpb.OutputProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::mindspore::irpb::TypeProto* type_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class InputProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.InputProto) */ { + public: + inline InputProto() : InputProto(nullptr) {} + virtual ~InputProto(); + + InputProto(const InputProto& from); + InputProto(InputProto&& from) noexcept + : InputProto() { + *this = ::std::move(from); + } + + inline InputProto& operator=(const InputProto& from) { + CopyFrom(from); + return *this; + } + inline InputProto& operator=(InputProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const InputProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const InputProto* internal_default_instance() { + return reinterpret_cast( + &_InputProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 10; + + friend void swap(InputProto& a, InputProto& b) { + a.Swap(&b); + } + inline void Swap(InputProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(InputProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline InputProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + InputProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const InputProto& from); + void MergeFrom(const InputProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(InputProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.InputProto"; + } + protected: + explicit InputProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef InputProto_EdgeType EdgeType; + static constexpr EdgeType DATA_EDGE = + InputProto_EdgeType_DATA_EDGE; + static constexpr EdgeType CONTROL_EDGE = + InputProto_EdgeType_CONTROL_EDGE; + static inline bool EdgeType_IsValid(int value) { + return InputProto_EdgeType_IsValid(value); + } + static constexpr EdgeType EdgeType_MIN = + InputProto_EdgeType_EdgeType_MIN; + static constexpr EdgeType EdgeType_MAX = + InputProto_EdgeType_EdgeType_MAX; + static constexpr int EdgeType_ARRAYSIZE = + InputProto_EdgeType_EdgeType_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + EdgeType_descriptor() { + return InputProto_EdgeType_descriptor(); + } + template + static inline const std::string& EdgeType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function EdgeType_Name."); + return InputProto_EdgeType_Name(enum_t_value); + } + static inline bool EdgeType_Parse(::PROTOBUF_NAMESPACE_ID::ConstStringParam name, + EdgeType* value) { + return InputProto_EdgeType_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kTypeFieldNumber = 2, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional .mindspore.irpb.InputProto.EdgeType type = 2; + bool has_type() const; + private: + bool _internal_has_type() const; + public: + void clear_type(); + ::mindspore::irpb::InputProto_EdgeType type() const; + void set_type(::mindspore::irpb::InputProto_EdgeType value); + private: + ::mindspore::irpb::InputProto_EdgeType _internal_type() const; + void _internal_set_type(::mindspore::irpb::InputProto_EdgeType value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.InputProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + int type_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class NodeProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.NodeProto) */ { + public: + inline NodeProto() : NodeProto(nullptr) {} + virtual ~NodeProto(); + + NodeProto(const NodeProto& from); + NodeProto(NodeProto&& from) noexcept + : NodeProto() { + *this = ::std::move(from); + } + + inline NodeProto& operator=(const NodeProto& from) { + CopyFrom(from); + return *this; + } + inline NodeProto& operator=(NodeProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const NodeProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const NodeProto* internal_default_instance() { + return reinterpret_cast( + &_NodeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 11; + + friend void swap(NodeProto& a, NodeProto& b) { + a.Swap(&b); + } + inline void Swap(NodeProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(NodeProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline NodeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + NodeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const NodeProto& from); + void MergeFrom(const NodeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(NodeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.NodeProto"; + } + protected: + explicit NodeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kInputFieldNumber = 1, + kAttributeFieldNumber = 5, + kNameFieldNumber = 2, + kOpTypeFieldNumber = 3, + kScopeFieldNumber = 4, + kFullNameFieldNumber = 8, + kInstanceNameFieldNumber = 10, + kOutputTypeFieldNumber = 6, + kOutputIFieldNumber = 7, + }; + // repeated .mindspore.irpb.InputProto input = 1; + int input_size() const; + private: + int _internal_input_size() const; + public: + void clear_input(); + ::mindspore::irpb::InputProto* mutable_input(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::InputProto >* + mutable_input(); + private: + const ::mindspore::irpb::InputProto& _internal_input(int index) const; + ::mindspore::irpb::InputProto* _internal_add_input(); + public: + const ::mindspore::irpb::InputProto& input(int index) const; + ::mindspore::irpb::InputProto* add_input(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::InputProto >& + input() const; + + // repeated .mindspore.irpb.AttributeProto attribute = 5; + int attribute_size() const; + private: + int _internal_attribute_size() const; + public: + void clear_attribute(); + ::mindspore::irpb::AttributeProto* mutable_attribute(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::AttributeProto >* + mutable_attribute(); + private: + const ::mindspore::irpb::AttributeProto& _internal_attribute(int index) const; + ::mindspore::irpb::AttributeProto* _internal_add_attribute(); + public: + const ::mindspore::irpb::AttributeProto& attribute(int index) const; + ::mindspore::irpb::AttributeProto* add_attribute(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::AttributeProto >& + attribute() const; + + // optional string name = 2; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string op_type = 3; + bool has_op_type() const; + private: + bool _internal_has_op_type() const; + public: + void clear_op_type(); + const std::string& op_type() const; + void set_op_type(const std::string& value); + void set_op_type(std::string&& value); + void set_op_type(const char* value); + void set_op_type(const char* value, size_t size); + std::string* mutable_op_type(); + std::string* release_op_type(); + void set_allocated_op_type(std::string* op_type); + private: + const std::string& _internal_op_type() const; + void _internal_set_op_type(const std::string& value); + std::string* _internal_mutable_op_type(); + public: + + // optional string scope = 4; + bool has_scope() const; + private: + bool _internal_has_scope() const; + public: + void clear_scope(); + const std::string& scope() const; + void set_scope(const std::string& value); + void set_scope(std::string&& value); + void set_scope(const char* value); + void set_scope(const char* value, size_t size); + std::string* mutable_scope(); + std::string* release_scope(); + void set_allocated_scope(std::string* scope); + private: + const std::string& _internal_scope() const; + void _internal_set_scope(const std::string& value); + std::string* _internal_mutable_scope(); + public: + + // optional string full_name = 8; + bool has_full_name() const; + private: + bool _internal_has_full_name() const; + public: + void clear_full_name(); + const std::string& full_name() const; + void set_full_name(const std::string& value); + void set_full_name(std::string&& value); + void set_full_name(const char* value); + void set_full_name(const char* value, size_t size); + std::string* mutable_full_name(); + std::string* release_full_name(); + void set_allocated_full_name(std::string* full_name); + private: + const std::string& _internal_full_name() const; + void _internal_set_full_name(const std::string& value); + std::string* _internal_mutable_full_name(); + public: + + // optional string instance_name = 10; + bool has_instance_name() const; + private: + bool _internal_has_instance_name() const; + public: + void clear_instance_name(); + const std::string& instance_name() const; + void set_instance_name(const std::string& value); + void set_instance_name(std::string&& value); + void set_instance_name(const char* value); + void set_instance_name(const char* value, size_t size); + std::string* mutable_instance_name(); + std::string* release_instance_name(); + void set_allocated_instance_name(std::string* instance_name); + private: + const std::string& _internal_instance_name() const; + void _internal_set_instance_name(const std::string& value); + std::string* _internal_mutable_instance_name(); + public: + + // optional .mindspore.irpb.TypeProto output_type = 6; + bool has_output_type() const; + private: + bool _internal_has_output_type() const; + public: + void clear_output_type(); + const ::mindspore::irpb::TypeProto& output_type() const; + ::mindspore::irpb::TypeProto* release_output_type(); + ::mindspore::irpb::TypeProto* mutable_output_type(); + void set_allocated_output_type(::mindspore::irpb::TypeProto* output_type); + private: + const ::mindspore::irpb::TypeProto& _internal_output_type() const; + ::mindspore::irpb::TypeProto* _internal_mutable_output_type(); + public: + void unsafe_arena_set_allocated_output_type( + ::mindspore::irpb::TypeProto* output_type); + ::mindspore::irpb::TypeProto* unsafe_arena_release_output_type(); + + // optional uint64 output_i = 7; + bool has_output_i() const; + private: + bool _internal_has_output_i() const; + public: + void clear_output_i(); + ::PROTOBUF_NAMESPACE_ID::uint64 output_i() const; + void set_output_i(::PROTOBUF_NAMESPACE_ID::uint64 value); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_output_i() const; + void _internal_set_output_i(::PROTOBUF_NAMESPACE_ID::uint64 value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.NodeProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::InputProto > input_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::AttributeProto > attribute_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr op_type_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr scope_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr full_name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr instance_name_; + ::mindspore::irpb::TypeProto* output_type_; + ::PROTOBUF_NAMESPACE_ID::uint64 output_i_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class ModelProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.ModelProto) */ { + public: + inline ModelProto() : ModelProto(nullptr) {} + virtual ~ModelProto(); + + ModelProto(const ModelProto& from); + ModelProto(ModelProto&& from) noexcept + : ModelProto() { + *this = ::std::move(from); + } + + inline ModelProto& operator=(const ModelProto& from) { + CopyFrom(from); + return *this; + } + inline ModelProto& operator=(ModelProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ModelProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ModelProto* internal_default_instance() { + return reinterpret_cast( + &_ModelProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 12; + + friend void swap(ModelProto& a, ModelProto& b) { + a.Swap(&b); + } + inline void Swap(ModelProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ModelProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ModelProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + ModelProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ModelProto& from); + void MergeFrom(const ModelProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ModelProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.ModelProto"; + } + protected: + explicit ModelProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDomainFieldNumber = 2, + kGraphFieldNumber = 4, + kMetadataOperatorsFieldNumber = 5, + kIrVersionFieldNumber = 1, + kModelVersionFieldNumber = 3, + }; + // optional string domain = 2; + bool has_domain() const; + private: + bool _internal_has_domain() const; + public: + void clear_domain(); + const std::string& domain() const; + void set_domain(const std::string& value); + void set_domain(std::string&& value); + void set_domain(const char* value); + void set_domain(const char* value, size_t size); + std::string* mutable_domain(); + std::string* release_domain(); + void set_allocated_domain(std::string* domain); + private: + const std::string& _internal_domain() const; + void _internal_set_domain(const std::string& value); + std::string* _internal_mutable_domain(); + public: + + // optional .mindspore.irpb.GraphProto graph = 4; + bool has_graph() const; + private: + bool _internal_has_graph() const; + public: + void clear_graph(); + const ::mindspore::irpb::GraphProto& graph() const; + ::mindspore::irpb::GraphProto* release_graph(); + ::mindspore::irpb::GraphProto* mutable_graph(); + void set_allocated_graph(::mindspore::irpb::GraphProto* graph); + private: + const ::mindspore::irpb::GraphProto& _internal_graph() const; + ::mindspore::irpb::GraphProto* _internal_mutable_graph(); + public: + void unsafe_arena_set_allocated_graph( + ::mindspore::irpb::GraphProto* graph); + ::mindspore::irpb::GraphProto* unsafe_arena_release_graph(); + + // optional .mindspore.irpb.OperatorSetProto metadata_operators = 5; + bool has_metadata_operators() const; + private: + bool _internal_has_metadata_operators() const; + public: + void clear_metadata_operators(); + const ::mindspore::irpb::OperatorSetProto& metadata_operators() const; + ::mindspore::irpb::OperatorSetProto* release_metadata_operators(); + ::mindspore::irpb::OperatorSetProto* mutable_metadata_operators(); + void set_allocated_metadata_operators(::mindspore::irpb::OperatorSetProto* metadata_operators); + private: + const ::mindspore::irpb::OperatorSetProto& _internal_metadata_operators() const; + ::mindspore::irpb::OperatorSetProto* _internal_mutable_metadata_operators(); + public: + void unsafe_arena_set_allocated_metadata_operators( + ::mindspore::irpb::OperatorSetProto* metadata_operators); + ::mindspore::irpb::OperatorSetProto* unsafe_arena_release_metadata_operators(); + + // optional int64 ir_version = 1; + bool has_ir_version() const; + private: + bool _internal_has_ir_version() const; + public: + void clear_ir_version(); + ::PROTOBUF_NAMESPACE_ID::int64 ir_version() const; + void set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_ir_version() const; + void _internal_set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional int64 model_version = 3; + bool has_model_version() const; + private: + bool _internal_has_model_version() const; + public: + void clear_model_version(); + ::PROTOBUF_NAMESPACE_ID::int64 model_version() const; + void set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_model_version() const; + void _internal_set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.ModelProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr domain_; + ::mindspore::irpb::GraphProto* graph_; + ::mindspore::irpb::OperatorSetProto* metadata_operators_; + ::PROTOBUF_NAMESPACE_ID::int64 ir_version_; + ::PROTOBUF_NAMESPACE_ID::int64 model_version_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class OperatorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.OperatorProto) */ { + public: + inline OperatorProto() : OperatorProto(nullptr) {} + virtual ~OperatorProto(); + + OperatorProto(const OperatorProto& from); + OperatorProto(OperatorProto&& from) noexcept + : OperatorProto() { + *this = ::std::move(from); + } + + inline OperatorProto& operator=(const OperatorProto& from) { + CopyFrom(from); + return *this; + } + inline OperatorProto& operator=(OperatorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const OperatorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const OperatorProto* internal_default_instance() { + return reinterpret_cast( + &_OperatorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 13; + + friend void swap(OperatorProto& a, OperatorProto& b) { + a.Swap(&b); + } + inline void Swap(OperatorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(OperatorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline OperatorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + OperatorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const OperatorProto& from); + void MergeFrom(const OperatorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(OperatorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.OperatorProto"; + } + protected: + explicit OperatorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kConfigFieldNumber = 2, + kObjInfoFieldNumber = 3, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional bytes config = 2; + bool has_config() const; + private: + bool _internal_has_config() const; + public: + void clear_config(); + const std::string& config() const; + void set_config(const std::string& value); + void set_config(std::string&& value); + void set_config(const char* value); + void set_config(const void* value, size_t size); + std::string* mutable_config(); + std::string* release_config(); + void set_allocated_config(std::string* config); + private: + const std::string& _internal_config() const; + void _internal_set_config(const std::string& value); + std::string* _internal_mutable_config(); + public: + + // optional bytes obj_info = 3; + bool has_obj_info() const; + private: + bool _internal_has_obj_info() const; + public: + void clear_obj_info(); + const std::string& obj_info() const; + void set_obj_info(const std::string& value); + void set_obj_info(std::string&& value); + void set_obj_info(const char* value); + void set_obj_info(const void* value, size_t size); + std::string* mutable_obj_info(); + std::string* release_obj_info(); + void set_allocated_obj_info(std::string* obj_info); + private: + const std::string& _internal_obj_info() const; + void _internal_set_obj_info(const std::string& value); + std::string* _internal_mutable_obj_info(); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.OperatorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr config_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr obj_info_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class OperatorSetProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.OperatorSetProto) */ { + public: + inline OperatorSetProto() : OperatorSetProto(nullptr) {} + virtual ~OperatorSetProto(); + + OperatorSetProto(const OperatorSetProto& from); + OperatorSetProto(OperatorSetProto&& from) noexcept + : OperatorSetProto() { + *this = ::std::move(from); + } + + inline OperatorSetProto& operator=(const OperatorSetProto& from) { + CopyFrom(from); + return *this; + } + inline OperatorSetProto& operator=(OperatorSetProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const OperatorSetProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const OperatorSetProto* internal_default_instance() { + return reinterpret_cast( + &_OperatorSetProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 14; + + friend void swap(OperatorSetProto& a, OperatorSetProto& b) { + a.Swap(&b); + } + inline void Swap(OperatorSetProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(OperatorSetProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline OperatorSetProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + OperatorSetProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const OperatorSetProto& from); + void MergeFrom(const OperatorSetProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(OperatorSetProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.OperatorSetProto"; + } + protected: + explicit OperatorSetProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kOperatorsFieldNumber = 1, + kSummaryFieldNumber = 2, + }; + // repeated .mindspore.irpb.OperatorProto operators = 1; + int operators_size() const; + private: + int _internal_operators_size() const; + public: + void clear_operators(); + ::mindspore::irpb::OperatorProto* mutable_operators(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::OperatorProto >* + mutable_operators(); + private: + const ::mindspore::irpb::OperatorProto& _internal_operators(int index) const; + ::mindspore::irpb::OperatorProto* _internal_add_operators(); + public: + const ::mindspore::irpb::OperatorProto& operators(int index) const; + ::mindspore::irpb::OperatorProto* add_operators(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::OperatorProto >& + operators() const; + + // optional string summary = 2; + bool has_summary() const; + private: + bool _internal_has_summary() const; + public: + void clear_summary(); + const std::string& summary() const; + void set_summary(const std::string& value); + void set_summary(std::string&& value); + void set_summary(const char* value); + void set_summary(const char* value, size_t size); + std::string* mutable_summary(); + std::string* release_summary(); + void set_allocated_summary(std::string* summary); + private: + const std::string& _internal_summary() const; + void _internal_set_summary(const std::string& value); + std::string* _internal_mutable_summary(); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.OperatorSetProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::OperatorProto > operators_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr summary_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class GraphProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.GraphProto) */ { + public: + inline GraphProto() : GraphProto(nullptr) {} + virtual ~GraphProto(); + + GraphProto(const GraphProto& from); + GraphProto(GraphProto&& from) noexcept + : GraphProto() { + *this = ::std::move(from); + } + + inline GraphProto& operator=(const GraphProto& from) { + CopyFrom(from); + return *this; + } + inline GraphProto& operator=(GraphProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const GraphProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const GraphProto* internal_default_instance() { + return reinterpret_cast( + &_GraphProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 15; + + friend void swap(GraphProto& a, GraphProto& b) { + a.Swap(&b); + } + inline void Swap(GraphProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(GraphProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline GraphProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + GraphProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const GraphProto& from); + void MergeFrom(const GraphProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(GraphProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.GraphProto"; + } + protected: + explicit GraphProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNodeFieldNumber = 1, + kParametersFieldNumber = 3, + kOutputsFieldNumber = 4, + kConstValsFieldNumber = 5, + kNameFieldNumber = 2, + }; + // repeated .mindspore.irpb.NodeProto node = 1; + int node_size() const; + private: + int _internal_node_size() const; + public: + void clear_node(); + ::mindspore::irpb::NodeProto* mutable_node(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NodeProto >* + mutable_node(); + private: + const ::mindspore::irpb::NodeProto& _internal_node(int index) const; + ::mindspore::irpb::NodeProto* _internal_add_node(); + public: + const ::mindspore::irpb::NodeProto& node(int index) const; + ::mindspore::irpb::NodeProto* add_node(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NodeProto >& + node() const; + + // repeated .mindspore.irpb.ParameterProto parameters = 3; + int parameters_size() const; + private: + int _internal_parameters_size() const; + public: + void clear_parameters(); + ::mindspore::irpb::ParameterProto* mutable_parameters(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::ParameterProto >* + mutable_parameters(); + private: + const ::mindspore::irpb::ParameterProto& _internal_parameters(int index) const; + ::mindspore::irpb::ParameterProto* _internal_add_parameters(); + public: + const ::mindspore::irpb::ParameterProto& parameters(int index) const; + ::mindspore::irpb::ParameterProto* add_parameters(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::ParameterProto >& + parameters() const; + + // repeated .mindspore.irpb.OutputProto outputs = 4; + int outputs_size() const; + private: + int _internal_outputs_size() const; + public: + void clear_outputs(); + ::mindspore::irpb::OutputProto* mutable_outputs(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::OutputProto >* + mutable_outputs(); + private: + const ::mindspore::irpb::OutputProto& _internal_outputs(int index) const; + ::mindspore::irpb::OutputProto* _internal_add_outputs(); + public: + const ::mindspore::irpb::OutputProto& outputs(int index) const; + ::mindspore::irpb::OutputProto* add_outputs(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::OutputProto >& + outputs() const; + + // repeated .mindspore.irpb.NamedValueProto const_vals = 5; + int const_vals_size() const; + private: + int _internal_const_vals_size() const; + public: + void clear_const_vals(); + ::mindspore::irpb::NamedValueProto* mutable_const_vals(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NamedValueProto >* + mutable_const_vals(); + private: + const ::mindspore::irpb::NamedValueProto& _internal_const_vals(int index) const; + ::mindspore::irpb::NamedValueProto* _internal_add_const_vals(); + public: + const ::mindspore::irpb::NamedValueProto& const_vals(int index) const; + ::mindspore::irpb::NamedValueProto* add_const_vals(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NamedValueProto >& + const_vals() const; + + // optional string name = 2; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.GraphProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NodeProto > node_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::ParameterProto > parameters_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::OutputProto > outputs_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NamedValueProto > const_vals_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.TensorProto) */ { + public: + inline TensorProto() : TensorProto(nullptr) {} + virtual ~TensorProto(); + + TensorProto(const TensorProto& from); + TensorProto(TensorProto&& from) noexcept + : TensorProto() { + *this = ::std::move(from); + } + + inline TensorProto& operator=(const TensorProto& from) { + CopyFrom(from); + return *this; + } + inline TensorProto& operator=(TensorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorProto* internal_default_instance() { + return reinterpret_cast( + &_TensorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 16; + + friend void swap(TensorProto& a, TensorProto& b) { + a.Swap(&b); + } + inline void Swap(TensorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TensorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorProto& from); + void MergeFrom(const TensorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.TensorProto"; + } + protected: + explicit TensorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fanf_5fir_2eproto); + return ::descriptor_table_mindspore_5fanf_5fir_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDimsFieldNumber = 1, + kFloatDataFieldNumber = 3, + kInt32DataFieldNumber = 4, + kInt64DataFieldNumber = 5, + kDoubleDataFieldNumber = 6, + kUint64DataFieldNumber = 7, + kRawDataFieldNumber = 8, + kDataTypeFieldNumber = 2, + }; + // repeated int64 dims = 1; + int dims_size() const; + private: + int _internal_dims_size() const; + public: + void clear_dims(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_dims(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_dims() const; + void _internal_add_dims(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_dims(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 dims(int index) const; + void set_dims(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_dims(::PROTOBUF_NAMESPACE_ID::int64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + dims() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_dims(); + + // repeated float float_data = 3 [packed = true]; + int float_data_size() const; + private: + int _internal_float_data_size() const; + public: + void clear_float_data(); + private: + float _internal_float_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_float_data() const; + void _internal_add_float_data(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_float_data(); + public: + float float_data(int index) const; + void set_float_data(int index, float value); + void add_float_data(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + float_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_float_data(); + + // repeated int32 int32_data = 4 [packed = true]; + int int32_data_size() const; + private: + int _internal_int32_data_size() const; + public: + void clear_int32_data(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_int32_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_int32_data() const; + void _internal_add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_int32_data(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 int32_data(int index) const; + void set_int32_data(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + int32_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_int32_data(); + + // repeated int64 int64_data = 5 [packed = true]; + int int64_data_size() const; + private: + int _internal_int64_data_size() const; + public: + void clear_int64_data(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_int64_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_int64_data() const; + void _internal_add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_int64_data(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 int64_data(int index) const; + void set_int64_data(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + int64_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_int64_data(); + + // repeated double double_data = 6 [packed = true]; + int double_data_size() const; + private: + int _internal_double_data_size() const; + public: + void clear_double_data(); + private: + double _internal_double_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + _internal_double_data() const; + void _internal_add_double_data(double value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + _internal_mutable_double_data(); + public: + double double_data(int index) const; + void set_double_data(int index, double value); + void add_double_data(double value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + double_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + mutable_double_data(); + + // repeated uint64 uint64_data = 7 [packed = true]; + int uint64_data_size() const; + private: + int _internal_uint64_data_size() const; + public: + void clear_uint64_data(); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_uint64_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& + _internal_uint64_data() const; + void _internal_add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* + _internal_mutable_uint64_data(); + public: + ::PROTOBUF_NAMESPACE_ID::uint64 uint64_data(int index) const; + void set_uint64_data(int index, ::PROTOBUF_NAMESPACE_ID::uint64 value); + void add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& + uint64_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* + mutable_uint64_data(); + + // optional bytes raw_data = 8; + bool has_raw_data() const; + private: + bool _internal_has_raw_data() const; + public: + void clear_raw_data(); + const std::string& raw_data() const; + void set_raw_data(const std::string& value); + void set_raw_data(std::string&& value); + void set_raw_data(const char* value); + void set_raw_data(const void* value, size_t size); + std::string* mutable_raw_data(); + std::string* release_raw_data(); + void set_allocated_raw_data(std::string* raw_data); + private: + const std::string& _internal_raw_data() const; + void _internal_set_raw_data(const std::string& value); + std::string* _internal_mutable_raw_data(); + public: + + // optional .mindspore.irpb.DataType data_type = 2; + bool has_data_type() const; + private: + bool _internal_has_data_type() const; + public: + void clear_data_type(); + ::mindspore::irpb::DataType data_type() const; + void set_data_type(::mindspore::irpb::DataType value); + private: + ::mindspore::irpb::DataType _internal_data_type() const; + void _internal_set_data_type(::mindspore::irpb::DataType value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.TensorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > dims_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > float_data_; + mutable std::atomic _float_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > int32_data_; + mutable std::atomic _int32_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > int64_data_; + mutable std::atomic _int64_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double > double_data_; + mutable std::atomic _double_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 > uint64_data_; + mutable std::atomic _uint64_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr raw_data_; + int data_type_; + friend struct ::TableStruct_mindspore_5fanf_5fir_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// ValueProto + +// optional .mindspore.irpb.DataType dtype = 1; +inline bool ValueProto::_internal_has_dtype() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool ValueProto::has_dtype() const { + return _internal_has_dtype(); +} +inline void ValueProto::clear_dtype() { + dtype_ = 0; + _has_bits_[0] &= ~0x00000010u; +} +inline ::mindspore::irpb::DataType ValueProto::_internal_dtype() const { + return static_cast< ::mindspore::irpb::DataType >(dtype_); +} +inline ::mindspore::irpb::DataType ValueProto::dtype() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.dtype) + return _internal_dtype(); +} +inline void ValueProto::_internal_set_dtype(::mindspore::irpb::DataType value) { + assert(::mindspore::irpb::DataType_IsValid(value)); + _has_bits_[0] |= 0x00000010u; + dtype_ = value; +} +inline void ValueProto::set_dtype(::mindspore::irpb::DataType value) { + _internal_set_dtype(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.dtype) +} + +// optional bool bool_val = 2; +inline bool ValueProto::_internal_has_bool_val() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + return value; +} +inline bool ValueProto::has_bool_val() const { + return _internal_has_bool_val(); +} +inline void ValueProto::clear_bool_val() { + bool_val_ = false; + _has_bits_[0] &= ~0x00000020u; +} +inline bool ValueProto::_internal_bool_val() const { + return bool_val_; +} +inline bool ValueProto::bool_val() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.bool_val) + return _internal_bool_val(); +} +inline void ValueProto::_internal_set_bool_val(bool value) { + _has_bits_[0] |= 0x00000020u; + bool_val_ = value; +} +inline void ValueProto::set_bool_val(bool value) { + _internal_set_bool_val(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.bool_val) +} + +// optional int64 int_val = 3; +inline bool ValueProto::_internal_has_int_val() const { + bool value = (_has_bits_[0] & 0x00000040u) != 0; + return value; +} +inline bool ValueProto::has_int_val() const { + return _internal_has_int_val(); +} +inline void ValueProto::clear_int_val() { + int_val_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000040u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ValueProto::_internal_int_val() const { + return int_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ValueProto::int_val() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.int_val) + return _internal_int_val(); +} +inline void ValueProto::_internal_set_int_val(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000040u; + int_val_ = value; +} +inline void ValueProto::set_int_val(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_int_val(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.int_val) +} + +// optional uint64 uint_val = 4; +inline bool ValueProto::_internal_has_uint_val() const { + bool value = (_has_bits_[0] & 0x00000080u) != 0; + return value; +} +inline bool ValueProto::has_uint_val() const { + return _internal_has_uint_val(); +} +inline void ValueProto::clear_uint_val() { + uint_val_ = PROTOBUF_ULONGLONG(0); + _has_bits_[0] &= ~0x00000080u; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 ValueProto::_internal_uint_val() const { + return uint_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 ValueProto::uint_val() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.uint_val) + return _internal_uint_val(); +} +inline void ValueProto::_internal_set_uint_val(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _has_bits_[0] |= 0x00000080u; + uint_val_ = value; +} +inline void ValueProto::set_uint_val(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_set_uint_val(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.uint_val) +} + +// optional float float_val = 5; +inline bool ValueProto::_internal_has_float_val() const { + bool value = (_has_bits_[0] & 0x00000200u) != 0; + return value; +} +inline bool ValueProto::has_float_val() const { + return _internal_has_float_val(); +} +inline void ValueProto::clear_float_val() { + float_val_ = 0; + _has_bits_[0] &= ~0x00000200u; +} +inline float ValueProto::_internal_float_val() const { + return float_val_; +} +inline float ValueProto::float_val() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.float_val) + return _internal_float_val(); +} +inline void ValueProto::_internal_set_float_val(float value) { + _has_bits_[0] |= 0x00000200u; + float_val_ = value; +} +inline void ValueProto::set_float_val(float value) { + _internal_set_float_val(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.float_val) +} + +// optional double double_val = 6; +inline bool ValueProto::_internal_has_double_val() const { + bool value = (_has_bits_[0] & 0x00000100u) != 0; + return value; +} +inline bool ValueProto::has_double_val() const { + return _internal_has_double_val(); +} +inline void ValueProto::clear_double_val() { + double_val_ = 0; + _has_bits_[0] &= ~0x00000100u; +} +inline double ValueProto::_internal_double_val() const { + return double_val_; +} +inline double ValueProto::double_val() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.double_val) + return _internal_double_val(); +} +inline void ValueProto::_internal_set_double_val(double value) { + _has_bits_[0] |= 0x00000100u; + double_val_ = value; +} +inline void ValueProto::set_double_val(double value) { + _internal_set_double_val(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.double_val) +} + +// optional string str_val = 7; +inline bool ValueProto::_internal_has_str_val() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool ValueProto::has_str_val() const { + return _internal_has_str_val(); +} +inline void ValueProto::clear_str_val() { + str_val_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& ValueProto::str_val() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.str_val) + return _internal_str_val(); +} +inline void ValueProto::set_str_val(const std::string& value) { + _internal_set_str_val(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.str_val) +} +inline std::string* ValueProto::mutable_str_val() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ValueProto.str_val) + return _internal_mutable_str_val(); +} +inline const std::string& ValueProto::_internal_str_val() const { + return str_val_.Get(); +} +inline void ValueProto::_internal_set_str_val(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + str_val_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ValueProto::set_str_val(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + str_val_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.ValueProto.str_val) +} +inline void ValueProto::set_str_val(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + str_val_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.ValueProto.str_val) +} +inline void ValueProto::set_str_val(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + str_val_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.ValueProto.str_val) +} +inline std::string* ValueProto::_internal_mutable_str_val() { + _has_bits_[0] |= 0x00000001u; + return str_val_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ValueProto::release_str_val() { + // @@protoc_insertion_point(field_release:mindspore.irpb.ValueProto.str_val) + if (!_internal_has_str_val()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return str_val_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ValueProto::set_allocated_str_val(std::string* str_val) { + if (str_val != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + str_val_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), str_val, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.ValueProto.str_val) +} + +// optional .mindspore.irpb.TensorProto tensor_val = 8; +inline bool ValueProto::_internal_has_tensor_val() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || tensor_val_ != nullptr); + return value; +} +inline bool ValueProto::has_tensor_val() const { + return _internal_has_tensor_val(); +} +inline void ValueProto::clear_tensor_val() { + if (tensor_val_ != nullptr) tensor_val_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const ::mindspore::irpb::TensorProto& ValueProto::_internal_tensor_val() const { + const ::mindspore::irpb::TensorProto* p = tensor_val_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_TensorProto_default_instance_); +} +inline const ::mindspore::irpb::TensorProto& ValueProto::tensor_val() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.tensor_val) + return _internal_tensor_val(); +} +inline void ValueProto::unsafe_arena_set_allocated_tensor_val( + ::mindspore::irpb::TensorProto* tensor_val) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(tensor_val_); + } + tensor_val_ = tensor_val; + if (tensor_val) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.ValueProto.tensor_val) +} +inline ::mindspore::irpb::TensorProto* ValueProto::release_tensor_val() { + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::TensorProto* temp = tensor_val_; + tensor_val_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::TensorProto* ValueProto::unsafe_arena_release_tensor_val() { + // @@protoc_insertion_point(field_release:mindspore.irpb.ValueProto.tensor_val) + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::TensorProto* temp = tensor_val_; + tensor_val_ = nullptr; + return temp; +} +inline ::mindspore::irpb::TensorProto* ValueProto::_internal_mutable_tensor_val() { + _has_bits_[0] |= 0x00000002u; + if (tensor_val_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::TensorProto>(GetArena()); + tensor_val_ = p; + } + return tensor_val_; +} +inline ::mindspore::irpb::TensorProto* ValueProto::mutable_tensor_val() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ValueProto.tensor_val) + return _internal_mutable_tensor_val(); +} +inline void ValueProto::set_allocated_tensor_val(::mindspore::irpb::TensorProto* tensor_val) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete tensor_val_; + } + if (tensor_val) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(tensor_val); + if (message_arena != submessage_arena) { + tensor_val = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, tensor_val, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + tensor_val_ = tensor_val; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.ValueProto.tensor_val) +} + +// optional .mindspore.irpb.GraphProto graph = 9; +inline bool ValueProto::_internal_has_graph() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + PROTOBUF_ASSUME(!value || graph_ != nullptr); + return value; +} +inline bool ValueProto::has_graph() const { + return _internal_has_graph(); +} +inline void ValueProto::clear_graph() { + if (graph_ != nullptr) graph_->Clear(); + _has_bits_[0] &= ~0x00000004u; +} +inline const ::mindspore::irpb::GraphProto& ValueProto::_internal_graph() const { + const ::mindspore::irpb::GraphProto* p = graph_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_GraphProto_default_instance_); +} +inline const ::mindspore::irpb::GraphProto& ValueProto::graph() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.graph) + return _internal_graph(); +} +inline void ValueProto::unsafe_arena_set_allocated_graph( + ::mindspore::irpb::GraphProto* graph) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(graph_); + } + graph_ = graph; + if (graph) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.ValueProto.graph) +} +inline ::mindspore::irpb::GraphProto* ValueProto::release_graph() { + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::GraphProto* temp = graph_; + graph_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::GraphProto* ValueProto::unsafe_arena_release_graph() { + // @@protoc_insertion_point(field_release:mindspore.irpb.ValueProto.graph) + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::GraphProto* temp = graph_; + graph_ = nullptr; + return temp; +} +inline ::mindspore::irpb::GraphProto* ValueProto::_internal_mutable_graph() { + _has_bits_[0] |= 0x00000004u; + if (graph_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::GraphProto>(GetArena()); + graph_ = p; + } + return graph_; +} +inline ::mindspore::irpb::GraphProto* ValueProto::mutable_graph() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ValueProto.graph) + return _internal_mutable_graph(); +} +inline void ValueProto::set_allocated_graph(::mindspore::irpb::GraphProto* graph) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete graph_; + } + if (graph) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(graph); + if (message_arena != submessage_arena) { + graph = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, graph, submessage_arena); + } + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + graph_ = graph; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.ValueProto.graph) +} + +// repeated bool bool_vals = 10; +inline int ValueProto::_internal_bool_vals_size() const { + return bool_vals_.size(); +} +inline int ValueProto::bool_vals_size() const { + return _internal_bool_vals_size(); +} +inline void ValueProto::clear_bool_vals() { + bool_vals_.Clear(); +} +inline bool ValueProto::_internal_bool_vals(int index) const { + return bool_vals_.Get(index); +} +inline bool ValueProto::bool_vals(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.bool_vals) + return _internal_bool_vals(index); +} +inline void ValueProto::set_bool_vals(int index, bool value) { + bool_vals_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.bool_vals) +} +inline void ValueProto::_internal_add_bool_vals(bool value) { + bool_vals_.Add(value); +} +inline void ValueProto::add_bool_vals(bool value) { + _internal_add_bool_vals(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.bool_vals) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >& +ValueProto::_internal_bool_vals() const { + return bool_vals_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >& +ValueProto::bool_vals() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.ValueProto.bool_vals) + return _internal_bool_vals(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >* +ValueProto::_internal_mutable_bool_vals() { + return &bool_vals_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >* +ValueProto::mutable_bool_vals() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.ValueProto.bool_vals) + return _internal_mutable_bool_vals(); +} + +// repeated int64 int_vals = 11; +inline int ValueProto::_internal_int_vals_size() const { + return int_vals_.size(); +} +inline int ValueProto::int_vals_size() const { + return _internal_int_vals_size(); +} +inline void ValueProto::clear_int_vals() { + int_vals_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ValueProto::_internal_int_vals(int index) const { + return int_vals_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ValueProto::int_vals(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.int_vals) + return _internal_int_vals(index); +} +inline void ValueProto::set_int_vals(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + int_vals_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.int_vals) +} +inline void ValueProto::_internal_add_int_vals(::PROTOBUF_NAMESPACE_ID::int64 value) { + int_vals_.Add(value); +} +inline void ValueProto::add_int_vals(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_int_vals(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.int_vals) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +ValueProto::_internal_int_vals() const { + return int_vals_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +ValueProto::int_vals() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.ValueProto.int_vals) + return _internal_int_vals(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +ValueProto::_internal_mutable_int_vals() { + return &int_vals_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +ValueProto::mutable_int_vals() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.ValueProto.int_vals) + return _internal_mutable_int_vals(); +} + +// repeated uint64 uint_vals = 12; +inline int ValueProto::_internal_uint_vals_size() const { + return uint_vals_.size(); +} +inline int ValueProto::uint_vals_size() const { + return _internal_uint_vals_size(); +} +inline void ValueProto::clear_uint_vals() { + uint_vals_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 ValueProto::_internal_uint_vals(int index) const { + return uint_vals_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 ValueProto::uint_vals(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.uint_vals) + return _internal_uint_vals(index); +} +inline void ValueProto::set_uint_vals(int index, ::PROTOBUF_NAMESPACE_ID::uint64 value) { + uint_vals_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.uint_vals) +} +inline void ValueProto::_internal_add_uint_vals(::PROTOBUF_NAMESPACE_ID::uint64 value) { + uint_vals_.Add(value); +} +inline void ValueProto::add_uint_vals(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_add_uint_vals(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.uint_vals) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& +ValueProto::_internal_uint_vals() const { + return uint_vals_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& +ValueProto::uint_vals() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.ValueProto.uint_vals) + return _internal_uint_vals(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* +ValueProto::_internal_mutable_uint_vals() { + return &uint_vals_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* +ValueProto::mutable_uint_vals() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.ValueProto.uint_vals) + return _internal_mutable_uint_vals(); +} + +// repeated float float_vals = 13; +inline int ValueProto::_internal_float_vals_size() const { + return float_vals_.size(); +} +inline int ValueProto::float_vals_size() const { + return _internal_float_vals_size(); +} +inline void ValueProto::clear_float_vals() { + float_vals_.Clear(); +} +inline float ValueProto::_internal_float_vals(int index) const { + return float_vals_.Get(index); +} +inline float ValueProto::float_vals(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.float_vals) + return _internal_float_vals(index); +} +inline void ValueProto::set_float_vals(int index, float value) { + float_vals_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.float_vals) +} +inline void ValueProto::_internal_add_float_vals(float value) { + float_vals_.Add(value); +} +inline void ValueProto::add_float_vals(float value) { + _internal_add_float_vals(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.float_vals) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +ValueProto::_internal_float_vals() const { + return float_vals_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +ValueProto::float_vals() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.ValueProto.float_vals) + return _internal_float_vals(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +ValueProto::_internal_mutable_float_vals() { + return &float_vals_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +ValueProto::mutable_float_vals() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.ValueProto.float_vals) + return _internal_mutable_float_vals(); +} + +// repeated double double_vals = 14; +inline int ValueProto::_internal_double_vals_size() const { + return double_vals_.size(); +} +inline int ValueProto::double_vals_size() const { + return _internal_double_vals_size(); +} +inline void ValueProto::clear_double_vals() { + double_vals_.Clear(); +} +inline double ValueProto::_internal_double_vals(int index) const { + return double_vals_.Get(index); +} +inline double ValueProto::double_vals(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.double_vals) + return _internal_double_vals(index); +} +inline void ValueProto::set_double_vals(int index, double value) { + double_vals_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.double_vals) +} +inline void ValueProto::_internal_add_double_vals(double value) { + double_vals_.Add(value); +} +inline void ValueProto::add_double_vals(double value) { + _internal_add_double_vals(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.double_vals) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +ValueProto::_internal_double_vals() const { + return double_vals_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +ValueProto::double_vals() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.ValueProto.double_vals) + return _internal_double_vals(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +ValueProto::_internal_mutable_double_vals() { + return &double_vals_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +ValueProto::mutable_double_vals() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.ValueProto.double_vals) + return _internal_mutable_double_vals(); +} + +// repeated string str_vals = 15; +inline int ValueProto::_internal_str_vals_size() const { + return str_vals_.size(); +} +inline int ValueProto::str_vals_size() const { + return _internal_str_vals_size(); +} +inline void ValueProto::clear_str_vals() { + str_vals_.Clear(); +} +inline std::string* ValueProto::add_str_vals() { + // @@protoc_insertion_point(field_add_mutable:mindspore.irpb.ValueProto.str_vals) + return _internal_add_str_vals(); +} +inline const std::string& ValueProto::_internal_str_vals(int index) const { + return str_vals_.Get(index); +} +inline const std::string& ValueProto::str_vals(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.str_vals) + return _internal_str_vals(index); +} +inline std::string* ValueProto::mutable_str_vals(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ValueProto.str_vals) + return str_vals_.Mutable(index); +} +inline void ValueProto::set_str_vals(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.str_vals) + str_vals_.Mutable(index)->assign(value); +} +inline void ValueProto::set_str_vals(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:mindspore.irpb.ValueProto.str_vals) + str_vals_.Mutable(index)->assign(std::move(value)); +} +inline void ValueProto::set_str_vals(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + str_vals_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.ValueProto.str_vals) +} +inline void ValueProto::set_str_vals(int index, const char* value, size_t size) { + str_vals_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.ValueProto.str_vals) +} +inline std::string* ValueProto::_internal_add_str_vals() { + return str_vals_.Add(); +} +inline void ValueProto::add_str_vals(const std::string& value) { + str_vals_.Add()->assign(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.str_vals) +} +inline void ValueProto::add_str_vals(std::string&& value) { + str_vals_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.str_vals) +} +inline void ValueProto::add_str_vals(const char* value) { + GOOGLE_DCHECK(value != nullptr); + str_vals_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:mindspore.irpb.ValueProto.str_vals) +} +inline void ValueProto::add_str_vals(const char* value, size_t size) { + str_vals_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:mindspore.irpb.ValueProto.str_vals) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +ValueProto::str_vals() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.ValueProto.str_vals) + return str_vals_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +ValueProto::mutable_str_vals() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.ValueProto.str_vals) + return &str_vals_; +} + +// repeated .mindspore.irpb.TensorProto tensor_vals = 16; +inline int ValueProto::_internal_tensor_vals_size() const { + return tensor_vals_.size(); +} +inline int ValueProto::tensor_vals_size() const { + return _internal_tensor_vals_size(); +} +inline void ValueProto::clear_tensor_vals() { + tensor_vals_.Clear(); +} +inline ::mindspore::irpb::TensorProto* ValueProto::mutable_tensor_vals(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ValueProto.tensor_vals) + return tensor_vals_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TensorProto >* +ValueProto::mutable_tensor_vals() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.ValueProto.tensor_vals) + return &tensor_vals_; +} +inline const ::mindspore::irpb::TensorProto& ValueProto::_internal_tensor_vals(int index) const { + return tensor_vals_.Get(index); +} +inline const ::mindspore::irpb::TensorProto& ValueProto::tensor_vals(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.tensor_vals) + return _internal_tensor_vals(index); +} +inline ::mindspore::irpb::TensorProto* ValueProto::_internal_add_tensor_vals() { + return tensor_vals_.Add(); +} +inline ::mindspore::irpb::TensorProto* ValueProto::add_tensor_vals() { + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.tensor_vals) + return _internal_add_tensor_vals(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TensorProto >& +ValueProto::tensor_vals() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.ValueProto.tensor_vals) + return tensor_vals_; +} + +// repeated .mindspore.irpb.GraphProto graphs = 17; +inline int ValueProto::_internal_graphs_size() const { + return graphs_.size(); +} +inline int ValueProto::graphs_size() const { + return _internal_graphs_size(); +} +inline void ValueProto::clear_graphs() { + graphs_.Clear(); +} +inline ::mindspore::irpb::GraphProto* ValueProto::mutable_graphs(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ValueProto.graphs) + return graphs_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::GraphProto >* +ValueProto::mutable_graphs() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.ValueProto.graphs) + return &graphs_; +} +inline const ::mindspore::irpb::GraphProto& ValueProto::_internal_graphs(int index) const { + return graphs_.Get(index); +} +inline const ::mindspore::irpb::GraphProto& ValueProto::graphs(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.graphs) + return _internal_graphs(index); +} +inline ::mindspore::irpb::GraphProto* ValueProto::_internal_add_graphs() { + return graphs_.Add(); +} +inline ::mindspore::irpb::GraphProto* ValueProto::add_graphs() { + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.graphs) + return _internal_add_graphs(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::GraphProto >& +ValueProto::graphs() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.ValueProto.graphs) + return graphs_; +} + +// repeated .mindspore.irpb.ValueProto values = 18; +inline int ValueProto::_internal_values_size() const { + return values_.size(); +} +inline int ValueProto::values_size() const { + return _internal_values_size(); +} +inline void ValueProto::clear_values() { + values_.Clear(); +} +inline ::mindspore::irpb::ValueProto* ValueProto::mutable_values(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ValueProto.values) + return values_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::ValueProto >* +ValueProto::mutable_values() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.ValueProto.values) + return &values_; +} +inline const ::mindspore::irpb::ValueProto& ValueProto::_internal_values(int index) const { + return values_.Get(index); +} +inline const ::mindspore::irpb::ValueProto& ValueProto::values(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.values) + return _internal_values(index); +} +inline ::mindspore::irpb::ValueProto* ValueProto::_internal_add_values() { + return values_.Add(); +} +inline ::mindspore::irpb::ValueProto* ValueProto::add_values() { + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.values) + return _internal_add_values(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::ValueProto >& +ValueProto::values() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.ValueProto.values) + return values_; +} + +// repeated .mindspore.irpb.NamedValueProto dict_val = 19; +inline int ValueProto::_internal_dict_val_size() const { + return dict_val_.size(); +} +inline int ValueProto::dict_val_size() const { + return _internal_dict_val_size(); +} +inline void ValueProto::clear_dict_val() { + dict_val_.Clear(); +} +inline ::mindspore::irpb::NamedValueProto* ValueProto::mutable_dict_val(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ValueProto.dict_val) + return dict_val_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NamedValueProto >* +ValueProto::mutable_dict_val() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.ValueProto.dict_val) + return &dict_val_; +} +inline const ::mindspore::irpb::NamedValueProto& ValueProto::_internal_dict_val(int index) const { + return dict_val_.Get(index); +} +inline const ::mindspore::irpb::NamedValueProto& ValueProto::dict_val(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.dict_val) + return _internal_dict_val(index); +} +inline ::mindspore::irpb::NamedValueProto* ValueProto::_internal_add_dict_val() { + return dict_val_.Add(); +} +inline ::mindspore::irpb::NamedValueProto* ValueProto::add_dict_val() { + // @@protoc_insertion_point(field_add:mindspore.irpb.ValueProto.dict_val) + return _internal_add_dict_val(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NamedValueProto >& +ValueProto::dict_val() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.ValueProto.dict_val) + return dict_val_; +} + +// optional .mindspore.irpb.TypeProto type_val = 20; +inline bool ValueProto::_internal_has_type_val() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + PROTOBUF_ASSUME(!value || type_val_ != nullptr); + return value; +} +inline bool ValueProto::has_type_val() const { + return _internal_has_type_val(); +} +inline void ValueProto::clear_type_val() { + if (type_val_ != nullptr) type_val_->Clear(); + _has_bits_[0] &= ~0x00000008u; +} +inline const ::mindspore::irpb::TypeProto& ValueProto::_internal_type_val() const { + const ::mindspore::irpb::TypeProto* p = type_val_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_TypeProto_default_instance_); +} +inline const ::mindspore::irpb::TypeProto& ValueProto::type_val() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ValueProto.type_val) + return _internal_type_val(); +} +inline void ValueProto::unsafe_arena_set_allocated_type_val( + ::mindspore::irpb::TypeProto* type_val) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(type_val_); + } + type_val_ = type_val; + if (type_val) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.ValueProto.type_val) +} +inline ::mindspore::irpb::TypeProto* ValueProto::release_type_val() { + _has_bits_[0] &= ~0x00000008u; + ::mindspore::irpb::TypeProto* temp = type_val_; + type_val_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::TypeProto* ValueProto::unsafe_arena_release_type_val() { + // @@protoc_insertion_point(field_release:mindspore.irpb.ValueProto.type_val) + _has_bits_[0] &= ~0x00000008u; + ::mindspore::irpb::TypeProto* temp = type_val_; + type_val_ = nullptr; + return temp; +} +inline ::mindspore::irpb::TypeProto* ValueProto::_internal_mutable_type_val() { + _has_bits_[0] |= 0x00000008u; + if (type_val_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::TypeProto>(GetArena()); + type_val_ = p; + } + return type_val_; +} +inline ::mindspore::irpb::TypeProto* ValueProto::mutable_type_val() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ValueProto.type_val) + return _internal_mutable_type_val(); +} +inline void ValueProto::set_allocated_type_val(::mindspore::irpb::TypeProto* type_val) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete type_val_; + } + if (type_val) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(type_val); + if (message_arena != submessage_arena) { + type_val = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, type_val, submessage_arena); + } + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + type_val_ = type_val; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.ValueProto.type_val) +} + +// ------------------------------------------------------------------- + +// AttributeProto + +// optional string name = 1; +inline bool AttributeProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool AttributeProto::has_name() const { + return _internal_has_name(); +} +inline void AttributeProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& AttributeProto::name() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.AttributeProto.name) + return _internal_name(); +} +inline void AttributeProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.AttributeProto.name) +} +inline std::string* AttributeProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.AttributeProto.name) + return _internal_mutable_name(); +} +inline const std::string& AttributeProto::_internal_name() const { + return name_.Get(); +} +inline void AttributeProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void AttributeProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.AttributeProto.name) +} +inline void AttributeProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.AttributeProto.name) +} +inline void AttributeProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.AttributeProto.name) +} +inline std::string* AttributeProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* AttributeProto::release_name() { + // @@protoc_insertion_point(field_release:mindspore.irpb.AttributeProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void AttributeProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.AttributeProto.name) +} + +// optional .mindspore.irpb.ValueProto value = 2; +inline bool AttributeProto::_internal_has_value() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || value_ != nullptr); + return value; +} +inline bool AttributeProto::has_value() const { + return _internal_has_value(); +} +inline void AttributeProto::clear_value() { + if (value_ != nullptr) value_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const ::mindspore::irpb::ValueProto& AttributeProto::_internal_value() const { + const ::mindspore::irpb::ValueProto* p = value_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_ValueProto_default_instance_); +} +inline const ::mindspore::irpb::ValueProto& AttributeProto::value() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.AttributeProto.value) + return _internal_value(); +} +inline void AttributeProto::unsafe_arena_set_allocated_value( + ::mindspore::irpb::ValueProto* value) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(value_); + } + value_ = value; + if (value) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.AttributeProto.value) +} +inline ::mindspore::irpb::ValueProto* AttributeProto::release_value() { + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::ValueProto* temp = value_; + value_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::ValueProto* AttributeProto::unsafe_arena_release_value() { + // @@protoc_insertion_point(field_release:mindspore.irpb.AttributeProto.value) + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::ValueProto* temp = value_; + value_ = nullptr; + return temp; +} +inline ::mindspore::irpb::ValueProto* AttributeProto::_internal_mutable_value() { + _has_bits_[0] |= 0x00000002u; + if (value_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::ValueProto>(GetArena()); + value_ = p; + } + return value_; +} +inline ::mindspore::irpb::ValueProto* AttributeProto::mutable_value() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.AttributeProto.value) + return _internal_mutable_value(); +} +inline void AttributeProto::set_allocated_value(::mindspore::irpb::ValueProto* value) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete value_; + } + if (value) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(value); + if (message_arena != submessage_arena) { + value = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, value, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + value_ = value; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.AttributeProto.value) +} + +// ------------------------------------------------------------------- + +// NamedValueProto + +// optional string key = 1; +inline bool NamedValueProto::_internal_has_key() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool NamedValueProto::has_key() const { + return _internal_has_key(); +} +inline void NamedValueProto::clear_key() { + key_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& NamedValueProto::key() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NamedValueProto.key) + return _internal_key(); +} +inline void NamedValueProto::set_key(const std::string& value) { + _internal_set_key(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.NamedValueProto.key) +} +inline std::string* NamedValueProto::mutable_key() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.NamedValueProto.key) + return _internal_mutable_key(); +} +inline const std::string& NamedValueProto::_internal_key() const { + return key_.Get(); +} +inline void NamedValueProto::_internal_set_key(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + key_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void NamedValueProto::set_key(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + key_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.NamedValueProto.key) +} +inline void NamedValueProto::set_key(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + key_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.NamedValueProto.key) +} +inline void NamedValueProto::set_key(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + key_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.NamedValueProto.key) +} +inline std::string* NamedValueProto::_internal_mutable_key() { + _has_bits_[0] |= 0x00000001u; + return key_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* NamedValueProto::release_key() { + // @@protoc_insertion_point(field_release:mindspore.irpb.NamedValueProto.key) + if (!_internal_has_key()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return key_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void NamedValueProto::set_allocated_key(std::string* key) { + if (key != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + key_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), key, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.NamedValueProto.key) +} + +// optional .mindspore.irpb.ValueProto value = 2; +inline bool NamedValueProto::_internal_has_value() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || value_ != nullptr); + return value; +} +inline bool NamedValueProto::has_value() const { + return _internal_has_value(); +} +inline void NamedValueProto::clear_value() { + if (value_ != nullptr) value_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const ::mindspore::irpb::ValueProto& NamedValueProto::_internal_value() const { + const ::mindspore::irpb::ValueProto* p = value_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_ValueProto_default_instance_); +} +inline const ::mindspore::irpb::ValueProto& NamedValueProto::value() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NamedValueProto.value) + return _internal_value(); +} +inline void NamedValueProto::unsafe_arena_set_allocated_value( + ::mindspore::irpb::ValueProto* value) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(value_); + } + value_ = value; + if (value) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.NamedValueProto.value) +} +inline ::mindspore::irpb::ValueProto* NamedValueProto::release_value() { + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::ValueProto* temp = value_; + value_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::ValueProto* NamedValueProto::unsafe_arena_release_value() { + // @@protoc_insertion_point(field_release:mindspore.irpb.NamedValueProto.value) + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::ValueProto* temp = value_; + value_ = nullptr; + return temp; +} +inline ::mindspore::irpb::ValueProto* NamedValueProto::_internal_mutable_value() { + _has_bits_[0] |= 0x00000002u; + if (value_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::ValueProto>(GetArena()); + value_ = p; + } + return value_; +} +inline ::mindspore::irpb::ValueProto* NamedValueProto::mutable_value() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.NamedValueProto.value) + return _internal_mutable_value(); +} +inline void NamedValueProto::set_allocated_value(::mindspore::irpb::ValueProto* value) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete value_; + } + if (value) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(value); + if (message_arena != submessage_arena) { + value = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, value, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + value_ = value; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.NamedValueProto.value) +} + +// ------------------------------------------------------------------- + +// TensorShapeProto_Dimension + +// optional int64 size = 1; +inline bool TensorShapeProto_Dimension::_internal_has_size() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TensorShapeProto_Dimension::has_size() const { + return _internal_has_size(); +} +inline void TensorShapeProto_Dimension::clear_size() { + size_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorShapeProto_Dimension::_internal_size() const { + return size_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorShapeProto_Dimension::size() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorShapeProto.Dimension.size) + return _internal_size(); +} +inline void TensorShapeProto_Dimension::_internal_set_size(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000002u; + size_ = value; +} +inline void TensorShapeProto_Dimension::set_size(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_size(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TensorShapeProto.Dimension.size) +} + +// optional string name = 2; +inline bool TensorShapeProto_Dimension::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TensorShapeProto_Dimension::has_name() const { + return _internal_has_name(); +} +inline void TensorShapeProto_Dimension::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& TensorShapeProto_Dimension::name() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorShapeProto.Dimension.name) + return _internal_name(); +} +inline void TensorShapeProto_Dimension::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TensorShapeProto.Dimension.name) +} +inline std::string* TensorShapeProto_Dimension::mutable_name() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.TensorShapeProto.Dimension.name) + return _internal_mutable_name(); +} +inline const std::string& TensorShapeProto_Dimension::_internal_name() const { + return name_.Get(); +} +inline void TensorShapeProto_Dimension::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TensorShapeProto_Dimension::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.TensorShapeProto.Dimension.name) +} +inline void TensorShapeProto_Dimension::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.TensorShapeProto.Dimension.name) +} +inline void TensorShapeProto_Dimension::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.TensorShapeProto.Dimension.name) +} +inline std::string* TensorShapeProto_Dimension::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TensorShapeProto_Dimension::release_name() { + // @@protoc_insertion_point(field_release:mindspore.irpb.TensorShapeProto.Dimension.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TensorShapeProto_Dimension::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.TensorShapeProto.Dimension.name) +} + +// ------------------------------------------------------------------- + +// TensorShapeProto + +// repeated .mindspore.irpb.TensorShapeProto.Dimension dim = 1; +inline int TensorShapeProto::_internal_dim_size() const { + return dim_.size(); +} +inline int TensorShapeProto::dim_size() const { + return _internal_dim_size(); +} +inline void TensorShapeProto::clear_dim() { + dim_.Clear(); +} +inline ::mindspore::irpb::TensorShapeProto_Dimension* TensorShapeProto::mutable_dim(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.TensorShapeProto.dim) + return dim_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TensorShapeProto_Dimension >* +TensorShapeProto::mutable_dim() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.TensorShapeProto.dim) + return &dim_; +} +inline const ::mindspore::irpb::TensorShapeProto_Dimension& TensorShapeProto::_internal_dim(int index) const { + return dim_.Get(index); +} +inline const ::mindspore::irpb::TensorShapeProto_Dimension& TensorShapeProto::dim(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorShapeProto.dim) + return _internal_dim(index); +} +inline ::mindspore::irpb::TensorShapeProto_Dimension* TensorShapeProto::_internal_add_dim() { + return dim_.Add(); +} +inline ::mindspore::irpb::TensorShapeProto_Dimension* TensorShapeProto::add_dim() { + // @@protoc_insertion_point(field_add:mindspore.irpb.TensorShapeProto.dim) + return _internal_add_dim(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TensorShapeProto_Dimension >& +TensorShapeProto::dim() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.TensorShapeProto.dim) + return dim_; +} + +// ------------------------------------------------------------------- + +// TypeProto_Tensor + +// optional .mindspore.irpb.DataType elem_type = 1; +inline bool TypeProto_Tensor::_internal_has_elem_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TypeProto_Tensor::has_elem_type() const { + return _internal_has_elem_type(); +} +inline void TypeProto_Tensor::clear_elem_type() { + elem_type_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::mindspore::irpb::DataType TypeProto_Tensor::_internal_elem_type() const { + return static_cast< ::mindspore::irpb::DataType >(elem_type_); +} +inline ::mindspore::irpb::DataType TypeProto_Tensor::elem_type() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TypeProto.Tensor.elem_type) + return _internal_elem_type(); +} +inline void TypeProto_Tensor::_internal_set_elem_type(::mindspore::irpb::DataType value) { + assert(::mindspore::irpb::DataType_IsValid(value)); + _has_bits_[0] |= 0x00000002u; + elem_type_ = value; +} +inline void TypeProto_Tensor::set_elem_type(::mindspore::irpb::DataType value) { + _internal_set_elem_type(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TypeProto.Tensor.elem_type) +} + +// optional .mindspore.irpb.TensorShapeProto shape = 2; +inline bool TypeProto_Tensor::_internal_has_shape() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || shape_ != nullptr); + return value; +} +inline bool TypeProto_Tensor::has_shape() const { + return _internal_has_shape(); +} +inline void TypeProto_Tensor::clear_shape() { + if (shape_ != nullptr) shape_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::mindspore::irpb::TensorShapeProto& TypeProto_Tensor::_internal_shape() const { + const ::mindspore::irpb::TensorShapeProto* p = shape_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_TensorShapeProto_default_instance_); +} +inline const ::mindspore::irpb::TensorShapeProto& TypeProto_Tensor::shape() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TypeProto.Tensor.shape) + return _internal_shape(); +} +inline void TypeProto_Tensor::unsafe_arena_set_allocated_shape( + ::mindspore::irpb::TensorShapeProto* shape) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(shape_); + } + shape_ = shape; + if (shape) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.TypeProto.Tensor.shape) +} +inline ::mindspore::irpb::TensorShapeProto* TypeProto_Tensor::release_shape() { + _has_bits_[0] &= ~0x00000001u; + ::mindspore::irpb::TensorShapeProto* temp = shape_; + shape_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::TensorShapeProto* TypeProto_Tensor::unsafe_arena_release_shape() { + // @@protoc_insertion_point(field_release:mindspore.irpb.TypeProto.Tensor.shape) + _has_bits_[0] &= ~0x00000001u; + ::mindspore::irpb::TensorShapeProto* temp = shape_; + shape_ = nullptr; + return temp; +} +inline ::mindspore::irpb::TensorShapeProto* TypeProto_Tensor::_internal_mutable_shape() { + _has_bits_[0] |= 0x00000001u; + if (shape_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::TensorShapeProto>(GetArena()); + shape_ = p; + } + return shape_; +} +inline ::mindspore::irpb::TensorShapeProto* TypeProto_Tensor::mutable_shape() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.TypeProto.Tensor.shape) + return _internal_mutable_shape(); +} +inline void TypeProto_Tensor::set_allocated_shape(::mindspore::irpb::TensorShapeProto* shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete shape_; + } + if (shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(shape); + if (message_arena != submessage_arena) { + shape = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, shape, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + shape_ = shape; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.TypeProto.Tensor.shape) +} + +// ------------------------------------------------------------------- + +// TypeProto_Sequence + +// repeated .mindspore.irpb.TypeProto elem_types = 1; +inline int TypeProto_Sequence::_internal_elem_types_size() const { + return elem_types_.size(); +} +inline int TypeProto_Sequence::elem_types_size() const { + return _internal_elem_types_size(); +} +inline void TypeProto_Sequence::clear_elem_types() { + elem_types_.Clear(); +} +inline ::mindspore::irpb::TypeProto* TypeProto_Sequence::mutable_elem_types(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.TypeProto.Sequence.elem_types) + return elem_types_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TypeProto >* +TypeProto_Sequence::mutable_elem_types() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.TypeProto.Sequence.elem_types) + return &elem_types_; +} +inline const ::mindspore::irpb::TypeProto& TypeProto_Sequence::_internal_elem_types(int index) const { + return elem_types_.Get(index); +} +inline const ::mindspore::irpb::TypeProto& TypeProto_Sequence::elem_types(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TypeProto.Sequence.elem_types) + return _internal_elem_types(index); +} +inline ::mindspore::irpb::TypeProto* TypeProto_Sequence::_internal_add_elem_types() { + return elem_types_.Add(); +} +inline ::mindspore::irpb::TypeProto* TypeProto_Sequence::add_elem_types() { + // @@protoc_insertion_point(field_add:mindspore.irpb.TypeProto.Sequence.elem_types) + return _internal_add_elem_types(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::TypeProto >& +TypeProto_Sequence::elem_types() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.TypeProto.Sequence.elem_types) + return elem_types_; +} + +// ------------------------------------------------------------------- + +// TypeProto + +// optional .mindspore.irpb.DataType data_type = 1; +inline bool TypeProto::_internal_has_data_type() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TypeProto::has_data_type() const { + return _internal_has_data_type(); +} +inline void TypeProto::clear_data_type() { + data_type_ = 0; + _has_bits_[0] &= ~0x00000001u; +} +inline ::mindspore::irpb::DataType TypeProto::_internal_data_type() const { + return static_cast< ::mindspore::irpb::DataType >(data_type_); +} +inline ::mindspore::irpb::DataType TypeProto::data_type() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TypeProto.data_type) + return _internal_data_type(); +} +inline void TypeProto::_internal_set_data_type(::mindspore::irpb::DataType value) { + assert(::mindspore::irpb::DataType_IsValid(value)); + _has_bits_[0] |= 0x00000001u; + data_type_ = value; +} +inline void TypeProto::set_data_type(::mindspore::irpb::DataType value) { + _internal_set_data_type(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TypeProto.data_type) +} + +// .mindspore.irpb.TypeProto.Tensor tensor_type = 2; +inline bool TypeProto::_internal_has_tensor_type() const { + return value_case() == kTensorType; +} +inline bool TypeProto::has_tensor_type() const { + return _internal_has_tensor_type(); +} +inline void TypeProto::set_has_tensor_type() { + _oneof_case_[0] = kTensorType; +} +inline void TypeProto::clear_tensor_type() { + if (_internal_has_tensor_type()) { + if (GetArena() == nullptr) { + delete value_.tensor_type_; + } + clear_has_value(); + } +} +inline ::mindspore::irpb::TypeProto_Tensor* TypeProto::release_tensor_type() { + // @@protoc_insertion_point(field_release:mindspore.irpb.TypeProto.tensor_type) + if (_internal_has_tensor_type()) { + clear_has_value(); + ::mindspore::irpb::TypeProto_Tensor* temp = value_.tensor_type_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.tensor_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::mindspore::irpb::TypeProto_Tensor& TypeProto::_internal_tensor_type() const { + return _internal_has_tensor_type() + ? *value_.tensor_type_ + : *reinterpret_cast< ::mindspore::irpb::TypeProto_Tensor*>(&::mindspore::irpb::_TypeProto_Tensor_default_instance_); +} +inline const ::mindspore::irpb::TypeProto_Tensor& TypeProto::tensor_type() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TypeProto.tensor_type) + return _internal_tensor_type(); +} +inline ::mindspore::irpb::TypeProto_Tensor* TypeProto::unsafe_arena_release_tensor_type() { + // @@protoc_insertion_point(field_unsafe_arena_release:mindspore.irpb.TypeProto.tensor_type) + if (_internal_has_tensor_type()) { + clear_has_value(); + ::mindspore::irpb::TypeProto_Tensor* temp = value_.tensor_type_; + value_.tensor_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void TypeProto::unsafe_arena_set_allocated_tensor_type(::mindspore::irpb::TypeProto_Tensor* tensor_type) { + clear_value(); + if (tensor_type) { + set_has_tensor_type(); + value_.tensor_type_ = tensor_type; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.TypeProto.tensor_type) +} +inline ::mindspore::irpb::TypeProto_Tensor* TypeProto::_internal_mutable_tensor_type() { + if (!_internal_has_tensor_type()) { + clear_value(); + set_has_tensor_type(); + value_.tensor_type_ = CreateMaybeMessage< ::mindspore::irpb::TypeProto_Tensor >(GetArena()); + } + return value_.tensor_type_; +} +inline ::mindspore::irpb::TypeProto_Tensor* TypeProto::mutable_tensor_type() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.TypeProto.tensor_type) + return _internal_mutable_tensor_type(); +} + +// .mindspore.irpb.TypeProto.Sequence sequence_type = 3; +inline bool TypeProto::_internal_has_sequence_type() const { + return value_case() == kSequenceType; +} +inline bool TypeProto::has_sequence_type() const { + return _internal_has_sequence_type(); +} +inline void TypeProto::set_has_sequence_type() { + _oneof_case_[0] = kSequenceType; +} +inline void TypeProto::clear_sequence_type() { + if (_internal_has_sequence_type()) { + if (GetArena() == nullptr) { + delete value_.sequence_type_; + } + clear_has_value(); + } +} +inline ::mindspore::irpb::TypeProto_Sequence* TypeProto::release_sequence_type() { + // @@protoc_insertion_point(field_release:mindspore.irpb.TypeProto.sequence_type) + if (_internal_has_sequence_type()) { + clear_has_value(); + ::mindspore::irpb::TypeProto_Sequence* temp = value_.sequence_type_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.sequence_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::mindspore::irpb::TypeProto_Sequence& TypeProto::_internal_sequence_type() const { + return _internal_has_sequence_type() + ? *value_.sequence_type_ + : *reinterpret_cast< ::mindspore::irpb::TypeProto_Sequence*>(&::mindspore::irpb::_TypeProto_Sequence_default_instance_); +} +inline const ::mindspore::irpb::TypeProto_Sequence& TypeProto::sequence_type() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TypeProto.sequence_type) + return _internal_sequence_type(); +} +inline ::mindspore::irpb::TypeProto_Sequence* TypeProto::unsafe_arena_release_sequence_type() { + // @@protoc_insertion_point(field_unsafe_arena_release:mindspore.irpb.TypeProto.sequence_type) + if (_internal_has_sequence_type()) { + clear_has_value(); + ::mindspore::irpb::TypeProto_Sequence* temp = value_.sequence_type_; + value_.sequence_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void TypeProto::unsafe_arena_set_allocated_sequence_type(::mindspore::irpb::TypeProto_Sequence* sequence_type) { + clear_value(); + if (sequence_type) { + set_has_sequence_type(); + value_.sequence_type_ = sequence_type; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.TypeProto.sequence_type) +} +inline ::mindspore::irpb::TypeProto_Sequence* TypeProto::_internal_mutable_sequence_type() { + if (!_internal_has_sequence_type()) { + clear_value(); + set_has_sequence_type(); + value_.sequence_type_ = CreateMaybeMessage< ::mindspore::irpb::TypeProto_Sequence >(GetArena()); + } + return value_.sequence_type_; +} +inline ::mindspore::irpb::TypeProto_Sequence* TypeProto::mutable_sequence_type() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.TypeProto.sequence_type) + return _internal_mutable_sequence_type(); +} + +inline bool TypeProto::has_value() const { + return value_case() != VALUE_NOT_SET; +} +inline void TypeProto::clear_has_value() { + _oneof_case_[0] = VALUE_NOT_SET; +} +inline TypeProto::ValueCase TypeProto::value_case() const { + return TypeProto::ValueCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + +// ParameterProto + +// optional string name = 1; +inline bool ParameterProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool ParameterProto::has_name() const { + return _internal_has_name(); +} +inline void ParameterProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& ParameterProto::name() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ParameterProto.name) + return _internal_name(); +} +inline void ParameterProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ParameterProto.name) +} +inline std::string* ParameterProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ParameterProto.name) + return _internal_mutable_name(); +} +inline const std::string& ParameterProto::_internal_name() const { + return name_.Get(); +} +inline void ParameterProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ParameterProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.ParameterProto.name) +} +inline void ParameterProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.ParameterProto.name) +} +inline void ParameterProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.ParameterProto.name) +} +inline std::string* ParameterProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ParameterProto::release_name() { + // @@protoc_insertion_point(field_release:mindspore.irpb.ParameterProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ParameterProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.ParameterProto.name) +} + +// optional .mindspore.irpb.TypeProto type = 2; +inline bool ParameterProto::_internal_has_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || type_ != nullptr); + return value; +} +inline bool ParameterProto::has_type() const { + return _internal_has_type(); +} +inline void ParameterProto::clear_type() { + if (type_ != nullptr) type_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const ::mindspore::irpb::TypeProto& ParameterProto::_internal_type() const { + const ::mindspore::irpb::TypeProto* p = type_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_TypeProto_default_instance_); +} +inline const ::mindspore::irpb::TypeProto& ParameterProto::type() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ParameterProto.type) + return _internal_type(); +} +inline void ParameterProto::unsafe_arena_set_allocated_type( + ::mindspore::irpb::TypeProto* type) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(type_); + } + type_ = type; + if (type) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.ParameterProto.type) +} +inline ::mindspore::irpb::TypeProto* ParameterProto::release_type() { + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::TypeProto* temp = type_; + type_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::TypeProto* ParameterProto::unsafe_arena_release_type() { + // @@protoc_insertion_point(field_release:mindspore.irpb.ParameterProto.type) + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::TypeProto* temp = type_; + type_ = nullptr; + return temp; +} +inline ::mindspore::irpb::TypeProto* ParameterProto::_internal_mutable_type() { + _has_bits_[0] |= 0x00000002u; + if (type_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::TypeProto>(GetArena()); + type_ = p; + } + return type_; +} +inline ::mindspore::irpb::TypeProto* ParameterProto::mutable_type() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ParameterProto.type) + return _internal_mutable_type(); +} +inline void ParameterProto::set_allocated_type(::mindspore::irpb::TypeProto* type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete type_; + } + if (type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(type); + if (message_arena != submessage_arena) { + type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, type, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + type_ = type; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.ParameterProto.type) +} + +// optional .mindspore.irpb.ValueProto default_val = 3; +inline bool ParameterProto::_internal_has_default_val() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + PROTOBUF_ASSUME(!value || default_val_ != nullptr); + return value; +} +inline bool ParameterProto::has_default_val() const { + return _internal_has_default_val(); +} +inline void ParameterProto::clear_default_val() { + if (default_val_ != nullptr) default_val_->Clear(); + _has_bits_[0] &= ~0x00000004u; +} +inline const ::mindspore::irpb::ValueProto& ParameterProto::_internal_default_val() const { + const ::mindspore::irpb::ValueProto* p = default_val_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_ValueProto_default_instance_); +} +inline const ::mindspore::irpb::ValueProto& ParameterProto::default_val() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ParameterProto.default_val) + return _internal_default_val(); +} +inline void ParameterProto::unsafe_arena_set_allocated_default_val( + ::mindspore::irpb::ValueProto* default_val) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(default_val_); + } + default_val_ = default_val; + if (default_val) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.ParameterProto.default_val) +} +inline ::mindspore::irpb::ValueProto* ParameterProto::release_default_val() { + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::ValueProto* temp = default_val_; + default_val_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::ValueProto* ParameterProto::unsafe_arena_release_default_val() { + // @@protoc_insertion_point(field_release:mindspore.irpb.ParameterProto.default_val) + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::ValueProto* temp = default_val_; + default_val_ = nullptr; + return temp; +} +inline ::mindspore::irpb::ValueProto* ParameterProto::_internal_mutable_default_val() { + _has_bits_[0] |= 0x00000004u; + if (default_val_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::ValueProto>(GetArena()); + default_val_ = p; + } + return default_val_; +} +inline ::mindspore::irpb::ValueProto* ParameterProto::mutable_default_val() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ParameterProto.default_val) + return _internal_mutable_default_val(); +} +inline void ParameterProto::set_allocated_default_val(::mindspore::irpb::ValueProto* default_val) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete default_val_; + } + if (default_val) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(default_val); + if (message_arena != submessage_arena) { + default_val = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, default_val, submessage_arena); + } + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + default_val_ = default_val; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.ParameterProto.default_val) +} + +// ------------------------------------------------------------------- + +// OutputProto + +// optional string name = 1; +inline bool OutputProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool OutputProto::has_name() const { + return _internal_has_name(); +} +inline void OutputProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& OutputProto::name() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.OutputProto.name) + return _internal_name(); +} +inline void OutputProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.OutputProto.name) +} +inline std::string* OutputProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.OutputProto.name) + return _internal_mutable_name(); +} +inline const std::string& OutputProto::_internal_name() const { + return name_.Get(); +} +inline void OutputProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void OutputProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.OutputProto.name) +} +inline void OutputProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.OutputProto.name) +} +inline void OutputProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.OutputProto.name) +} +inline std::string* OutputProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* OutputProto::release_name() { + // @@protoc_insertion_point(field_release:mindspore.irpb.OutputProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void OutputProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.OutputProto.name) +} + +// optional .mindspore.irpb.TypeProto type = 2; +inline bool OutputProto::_internal_has_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || type_ != nullptr); + return value; +} +inline bool OutputProto::has_type() const { + return _internal_has_type(); +} +inline void OutputProto::clear_type() { + if (type_ != nullptr) type_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const ::mindspore::irpb::TypeProto& OutputProto::_internal_type() const { + const ::mindspore::irpb::TypeProto* p = type_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_TypeProto_default_instance_); +} +inline const ::mindspore::irpb::TypeProto& OutputProto::type() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.OutputProto.type) + return _internal_type(); +} +inline void OutputProto::unsafe_arena_set_allocated_type( + ::mindspore::irpb::TypeProto* type) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(type_); + } + type_ = type; + if (type) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.OutputProto.type) +} +inline ::mindspore::irpb::TypeProto* OutputProto::release_type() { + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::TypeProto* temp = type_; + type_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::TypeProto* OutputProto::unsafe_arena_release_type() { + // @@protoc_insertion_point(field_release:mindspore.irpb.OutputProto.type) + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::TypeProto* temp = type_; + type_ = nullptr; + return temp; +} +inline ::mindspore::irpb::TypeProto* OutputProto::_internal_mutable_type() { + _has_bits_[0] |= 0x00000002u; + if (type_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::TypeProto>(GetArena()); + type_ = p; + } + return type_; +} +inline ::mindspore::irpb::TypeProto* OutputProto::mutable_type() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.OutputProto.type) + return _internal_mutable_type(); +} +inline void OutputProto::set_allocated_type(::mindspore::irpb::TypeProto* type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete type_; + } + if (type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(type); + if (message_arena != submessage_arena) { + type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, type, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + type_ = type; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.OutputProto.type) +} + +// ------------------------------------------------------------------- + +// InputProto + +// optional string name = 1; +inline bool InputProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool InputProto::has_name() const { + return _internal_has_name(); +} +inline void InputProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& InputProto::name() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.InputProto.name) + return _internal_name(); +} +inline void InputProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.InputProto.name) +} +inline std::string* InputProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.InputProto.name) + return _internal_mutable_name(); +} +inline const std::string& InputProto::_internal_name() const { + return name_.Get(); +} +inline void InputProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void InputProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.InputProto.name) +} +inline void InputProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.InputProto.name) +} +inline void InputProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.InputProto.name) +} +inline std::string* InputProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* InputProto::release_name() { + // @@protoc_insertion_point(field_release:mindspore.irpb.InputProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void InputProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.InputProto.name) +} + +// optional .mindspore.irpb.InputProto.EdgeType type = 2; +inline bool InputProto::_internal_has_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool InputProto::has_type() const { + return _internal_has_type(); +} +inline void InputProto::clear_type() { + type_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::mindspore::irpb::InputProto_EdgeType InputProto::_internal_type() const { + return static_cast< ::mindspore::irpb::InputProto_EdgeType >(type_); +} +inline ::mindspore::irpb::InputProto_EdgeType InputProto::type() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.InputProto.type) + return _internal_type(); +} +inline void InputProto::_internal_set_type(::mindspore::irpb::InputProto_EdgeType value) { + assert(::mindspore::irpb::InputProto_EdgeType_IsValid(value)); + _has_bits_[0] |= 0x00000002u; + type_ = value; +} +inline void InputProto::set_type(::mindspore::irpb::InputProto_EdgeType value) { + _internal_set_type(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.InputProto.type) +} + +// ------------------------------------------------------------------- + +// NodeProto + +// repeated .mindspore.irpb.InputProto input = 1; +inline int NodeProto::_internal_input_size() const { + return input_.size(); +} +inline int NodeProto::input_size() const { + return _internal_input_size(); +} +inline void NodeProto::clear_input() { + input_.Clear(); +} +inline ::mindspore::irpb::InputProto* NodeProto::mutable_input(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.NodeProto.input) + return input_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::InputProto >* +NodeProto::mutable_input() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.NodeProto.input) + return &input_; +} +inline const ::mindspore::irpb::InputProto& NodeProto::_internal_input(int index) const { + return input_.Get(index); +} +inline const ::mindspore::irpb::InputProto& NodeProto::input(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NodeProto.input) + return _internal_input(index); +} +inline ::mindspore::irpb::InputProto* NodeProto::_internal_add_input() { + return input_.Add(); +} +inline ::mindspore::irpb::InputProto* NodeProto::add_input() { + // @@protoc_insertion_point(field_add:mindspore.irpb.NodeProto.input) + return _internal_add_input(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::InputProto >& +NodeProto::input() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.NodeProto.input) + return input_; +} + +// optional string name = 2; +inline bool NodeProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool NodeProto::has_name() const { + return _internal_has_name(); +} +inline void NodeProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& NodeProto::name() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NodeProto.name) + return _internal_name(); +} +inline void NodeProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.NodeProto.name) +} +inline std::string* NodeProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.NodeProto.name) + return _internal_mutable_name(); +} +inline const std::string& NodeProto::_internal_name() const { + return name_.Get(); +} +inline void NodeProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void NodeProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.NodeProto.name) +} +inline void NodeProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.NodeProto.name) +} +inline void NodeProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.NodeProto.name) +} +inline std::string* NodeProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* NodeProto::release_name() { + // @@protoc_insertion_point(field_release:mindspore.irpb.NodeProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void NodeProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.NodeProto.name) +} + +// optional string op_type = 3; +inline bool NodeProto::_internal_has_op_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool NodeProto::has_op_type() const { + return _internal_has_op_type(); +} +inline void NodeProto::clear_op_type() { + op_type_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& NodeProto::op_type() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NodeProto.op_type) + return _internal_op_type(); +} +inline void NodeProto::set_op_type(const std::string& value) { + _internal_set_op_type(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.NodeProto.op_type) +} +inline std::string* NodeProto::mutable_op_type() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.NodeProto.op_type) + return _internal_mutable_op_type(); +} +inline const std::string& NodeProto::_internal_op_type() const { + return op_type_.Get(); +} +inline void NodeProto::_internal_set_op_type(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + op_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void NodeProto::set_op_type(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + op_type_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.NodeProto.op_type) +} +inline void NodeProto::set_op_type(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + op_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.NodeProto.op_type) +} +inline void NodeProto::set_op_type(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + op_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.NodeProto.op_type) +} +inline std::string* NodeProto::_internal_mutable_op_type() { + _has_bits_[0] |= 0x00000002u; + return op_type_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* NodeProto::release_op_type() { + // @@protoc_insertion_point(field_release:mindspore.irpb.NodeProto.op_type) + if (!_internal_has_op_type()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return op_type_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void NodeProto::set_allocated_op_type(std::string* op_type) { + if (op_type != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + op_type_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), op_type, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.NodeProto.op_type) +} + +// optional string scope = 4; +inline bool NodeProto::_internal_has_scope() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool NodeProto::has_scope() const { + return _internal_has_scope(); +} +inline void NodeProto::clear_scope() { + scope_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& NodeProto::scope() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NodeProto.scope) + return _internal_scope(); +} +inline void NodeProto::set_scope(const std::string& value) { + _internal_set_scope(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.NodeProto.scope) +} +inline std::string* NodeProto::mutable_scope() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.NodeProto.scope) + return _internal_mutable_scope(); +} +inline const std::string& NodeProto::_internal_scope() const { + return scope_.Get(); +} +inline void NodeProto::_internal_set_scope(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + scope_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void NodeProto::set_scope(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + scope_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.NodeProto.scope) +} +inline void NodeProto::set_scope(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + scope_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.NodeProto.scope) +} +inline void NodeProto::set_scope(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + scope_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.NodeProto.scope) +} +inline std::string* NodeProto::_internal_mutable_scope() { + _has_bits_[0] |= 0x00000004u; + return scope_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* NodeProto::release_scope() { + // @@protoc_insertion_point(field_release:mindspore.irpb.NodeProto.scope) + if (!_internal_has_scope()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return scope_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void NodeProto::set_allocated_scope(std::string* scope) { + if (scope != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + scope_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), scope, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.NodeProto.scope) +} + +// repeated .mindspore.irpb.AttributeProto attribute = 5; +inline int NodeProto::_internal_attribute_size() const { + return attribute_.size(); +} +inline int NodeProto::attribute_size() const { + return _internal_attribute_size(); +} +inline void NodeProto::clear_attribute() { + attribute_.Clear(); +} +inline ::mindspore::irpb::AttributeProto* NodeProto::mutable_attribute(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.NodeProto.attribute) + return attribute_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::AttributeProto >* +NodeProto::mutable_attribute() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.NodeProto.attribute) + return &attribute_; +} +inline const ::mindspore::irpb::AttributeProto& NodeProto::_internal_attribute(int index) const { + return attribute_.Get(index); +} +inline const ::mindspore::irpb::AttributeProto& NodeProto::attribute(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NodeProto.attribute) + return _internal_attribute(index); +} +inline ::mindspore::irpb::AttributeProto* NodeProto::_internal_add_attribute() { + return attribute_.Add(); +} +inline ::mindspore::irpb::AttributeProto* NodeProto::add_attribute() { + // @@protoc_insertion_point(field_add:mindspore.irpb.NodeProto.attribute) + return _internal_add_attribute(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::AttributeProto >& +NodeProto::attribute() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.NodeProto.attribute) + return attribute_; +} + +// optional .mindspore.irpb.TypeProto output_type = 6; +inline bool NodeProto::_internal_has_output_type() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + PROTOBUF_ASSUME(!value || output_type_ != nullptr); + return value; +} +inline bool NodeProto::has_output_type() const { + return _internal_has_output_type(); +} +inline void NodeProto::clear_output_type() { + if (output_type_ != nullptr) output_type_->Clear(); + _has_bits_[0] &= ~0x00000020u; +} +inline const ::mindspore::irpb::TypeProto& NodeProto::_internal_output_type() const { + const ::mindspore::irpb::TypeProto* p = output_type_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_TypeProto_default_instance_); +} +inline const ::mindspore::irpb::TypeProto& NodeProto::output_type() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NodeProto.output_type) + return _internal_output_type(); +} +inline void NodeProto::unsafe_arena_set_allocated_output_type( + ::mindspore::irpb::TypeProto* output_type) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(output_type_); + } + output_type_ = output_type; + if (output_type) { + _has_bits_[0] |= 0x00000020u; + } else { + _has_bits_[0] &= ~0x00000020u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.NodeProto.output_type) +} +inline ::mindspore::irpb::TypeProto* NodeProto::release_output_type() { + _has_bits_[0] &= ~0x00000020u; + ::mindspore::irpb::TypeProto* temp = output_type_; + output_type_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::TypeProto* NodeProto::unsafe_arena_release_output_type() { + // @@protoc_insertion_point(field_release:mindspore.irpb.NodeProto.output_type) + _has_bits_[0] &= ~0x00000020u; + ::mindspore::irpb::TypeProto* temp = output_type_; + output_type_ = nullptr; + return temp; +} +inline ::mindspore::irpb::TypeProto* NodeProto::_internal_mutable_output_type() { + _has_bits_[0] |= 0x00000020u; + if (output_type_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::TypeProto>(GetArena()); + output_type_ = p; + } + return output_type_; +} +inline ::mindspore::irpb::TypeProto* NodeProto::mutable_output_type() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.NodeProto.output_type) + return _internal_mutable_output_type(); +} +inline void NodeProto::set_allocated_output_type(::mindspore::irpb::TypeProto* output_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete output_type_; + } + if (output_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(output_type); + if (message_arena != submessage_arena) { + output_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, output_type, submessage_arena); + } + _has_bits_[0] |= 0x00000020u; + } else { + _has_bits_[0] &= ~0x00000020u; + } + output_type_ = output_type; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.NodeProto.output_type) +} + +// optional uint64 output_i = 7; +inline bool NodeProto::_internal_has_output_i() const { + bool value = (_has_bits_[0] & 0x00000040u) != 0; + return value; +} +inline bool NodeProto::has_output_i() const { + return _internal_has_output_i(); +} +inline void NodeProto::clear_output_i() { + output_i_ = PROTOBUF_ULONGLONG(0); + _has_bits_[0] &= ~0x00000040u; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 NodeProto::_internal_output_i() const { + return output_i_; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 NodeProto::output_i() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NodeProto.output_i) + return _internal_output_i(); +} +inline void NodeProto::_internal_set_output_i(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _has_bits_[0] |= 0x00000040u; + output_i_ = value; +} +inline void NodeProto::set_output_i(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_set_output_i(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.NodeProto.output_i) +} + +// optional string full_name = 8; +inline bool NodeProto::_internal_has_full_name() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool NodeProto::has_full_name() const { + return _internal_has_full_name(); +} +inline void NodeProto::clear_full_name() { + full_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000008u; +} +inline const std::string& NodeProto::full_name() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NodeProto.full_name) + return _internal_full_name(); +} +inline void NodeProto::set_full_name(const std::string& value) { + _internal_set_full_name(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.NodeProto.full_name) +} +inline std::string* NodeProto::mutable_full_name() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.NodeProto.full_name) + return _internal_mutable_full_name(); +} +inline const std::string& NodeProto::_internal_full_name() const { + return full_name_.Get(); +} +inline void NodeProto::_internal_set_full_name(const std::string& value) { + _has_bits_[0] |= 0x00000008u; + full_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void NodeProto::set_full_name(std::string&& value) { + _has_bits_[0] |= 0x00000008u; + full_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.NodeProto.full_name) +} +inline void NodeProto::set_full_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000008u; + full_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.NodeProto.full_name) +} +inline void NodeProto::set_full_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000008u; + full_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.NodeProto.full_name) +} +inline std::string* NodeProto::_internal_mutable_full_name() { + _has_bits_[0] |= 0x00000008u; + return full_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* NodeProto::release_full_name() { + // @@protoc_insertion_point(field_release:mindspore.irpb.NodeProto.full_name) + if (!_internal_has_full_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000008u; + return full_name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void NodeProto::set_allocated_full_name(std::string* full_name) { + if (full_name != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + full_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), full_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.NodeProto.full_name) +} + +// optional string instance_name = 10; +inline bool NodeProto::_internal_has_instance_name() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool NodeProto::has_instance_name() const { + return _internal_has_instance_name(); +} +inline void NodeProto::clear_instance_name() { + instance_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000010u; +} +inline const std::string& NodeProto::instance_name() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.NodeProto.instance_name) + return _internal_instance_name(); +} +inline void NodeProto::set_instance_name(const std::string& value) { + _internal_set_instance_name(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.NodeProto.instance_name) +} +inline std::string* NodeProto::mutable_instance_name() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.NodeProto.instance_name) + return _internal_mutable_instance_name(); +} +inline const std::string& NodeProto::_internal_instance_name() const { + return instance_name_.Get(); +} +inline void NodeProto::_internal_set_instance_name(const std::string& value) { + _has_bits_[0] |= 0x00000010u; + instance_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void NodeProto::set_instance_name(std::string&& value) { + _has_bits_[0] |= 0x00000010u; + instance_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.NodeProto.instance_name) +} +inline void NodeProto::set_instance_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000010u; + instance_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.NodeProto.instance_name) +} +inline void NodeProto::set_instance_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000010u; + instance_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.NodeProto.instance_name) +} +inline std::string* NodeProto::_internal_mutable_instance_name() { + _has_bits_[0] |= 0x00000010u; + return instance_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* NodeProto::release_instance_name() { + // @@protoc_insertion_point(field_release:mindspore.irpb.NodeProto.instance_name) + if (!_internal_has_instance_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000010u; + return instance_name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void NodeProto::set_allocated_instance_name(std::string* instance_name) { + if (instance_name != nullptr) { + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + instance_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), instance_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.NodeProto.instance_name) +} + +// ------------------------------------------------------------------- + +// ModelProto + +// optional int64 ir_version = 1; +inline bool ModelProto::_internal_has_ir_version() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool ModelProto::has_ir_version() const { + return _internal_has_ir_version(); +} +inline void ModelProto::clear_ir_version() { + ir_version_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000008u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::_internal_ir_version() const { + return ir_version_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::ir_version() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ModelProto.ir_version) + return _internal_ir_version(); +} +inline void ModelProto::_internal_set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000008u; + ir_version_ = value; +} +inline void ModelProto::set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_ir_version(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ModelProto.ir_version) +} + +// optional string domain = 2; +inline bool ModelProto::_internal_has_domain() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool ModelProto::has_domain() const { + return _internal_has_domain(); +} +inline void ModelProto::clear_domain() { + domain_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& ModelProto::domain() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ModelProto.domain) + return _internal_domain(); +} +inline void ModelProto::set_domain(const std::string& value) { + _internal_set_domain(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ModelProto.domain) +} +inline std::string* ModelProto::mutable_domain() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ModelProto.domain) + return _internal_mutable_domain(); +} +inline const std::string& ModelProto::_internal_domain() const { + return domain_.Get(); +} +inline void ModelProto::_internal_set_domain(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ModelProto::set_domain(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + domain_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.ModelProto.domain) +} +inline void ModelProto::set_domain(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.ModelProto.domain) +} +inline void ModelProto::set_domain(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.ModelProto.domain) +} +inline std::string* ModelProto::_internal_mutable_domain() { + _has_bits_[0] |= 0x00000001u; + return domain_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ModelProto::release_domain() { + // @@protoc_insertion_point(field_release:mindspore.irpb.ModelProto.domain) + if (!_internal_has_domain()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return domain_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ModelProto::set_allocated_domain(std::string* domain) { + if (domain != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + domain_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), domain, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.ModelProto.domain) +} + +// optional int64 model_version = 3; +inline bool ModelProto::_internal_has_model_version() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool ModelProto::has_model_version() const { + return _internal_has_model_version(); +} +inline void ModelProto::clear_model_version() { + model_version_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000010u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::_internal_model_version() const { + return model_version_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::model_version() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ModelProto.model_version) + return _internal_model_version(); +} +inline void ModelProto::_internal_set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000010u; + model_version_ = value; +} +inline void ModelProto::set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_model_version(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.ModelProto.model_version) +} + +// optional .mindspore.irpb.GraphProto graph = 4; +inline bool ModelProto::_internal_has_graph() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || graph_ != nullptr); + return value; +} +inline bool ModelProto::has_graph() const { + return _internal_has_graph(); +} +inline void ModelProto::clear_graph() { + if (graph_ != nullptr) graph_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const ::mindspore::irpb::GraphProto& ModelProto::_internal_graph() const { + const ::mindspore::irpb::GraphProto* p = graph_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_GraphProto_default_instance_); +} +inline const ::mindspore::irpb::GraphProto& ModelProto::graph() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ModelProto.graph) + return _internal_graph(); +} +inline void ModelProto::unsafe_arena_set_allocated_graph( + ::mindspore::irpb::GraphProto* graph) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(graph_); + } + graph_ = graph; + if (graph) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.ModelProto.graph) +} +inline ::mindspore::irpb::GraphProto* ModelProto::release_graph() { + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::GraphProto* temp = graph_; + graph_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::GraphProto* ModelProto::unsafe_arena_release_graph() { + // @@protoc_insertion_point(field_release:mindspore.irpb.ModelProto.graph) + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::GraphProto* temp = graph_; + graph_ = nullptr; + return temp; +} +inline ::mindspore::irpb::GraphProto* ModelProto::_internal_mutable_graph() { + _has_bits_[0] |= 0x00000002u; + if (graph_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::GraphProto>(GetArena()); + graph_ = p; + } + return graph_; +} +inline ::mindspore::irpb::GraphProto* ModelProto::mutable_graph() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ModelProto.graph) + return _internal_mutable_graph(); +} +inline void ModelProto::set_allocated_graph(::mindspore::irpb::GraphProto* graph) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete graph_; + } + if (graph) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(graph); + if (message_arena != submessage_arena) { + graph = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, graph, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + graph_ = graph; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.ModelProto.graph) +} + +// optional .mindspore.irpb.OperatorSetProto metadata_operators = 5; +inline bool ModelProto::_internal_has_metadata_operators() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + PROTOBUF_ASSUME(!value || metadata_operators_ != nullptr); + return value; +} +inline bool ModelProto::has_metadata_operators() const { + return _internal_has_metadata_operators(); +} +inline void ModelProto::clear_metadata_operators() { + if (metadata_operators_ != nullptr) metadata_operators_->Clear(); + _has_bits_[0] &= ~0x00000004u; +} +inline const ::mindspore::irpb::OperatorSetProto& ModelProto::_internal_metadata_operators() const { + const ::mindspore::irpb::OperatorSetProto* p = metadata_operators_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_OperatorSetProto_default_instance_); +} +inline const ::mindspore::irpb::OperatorSetProto& ModelProto::metadata_operators() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.ModelProto.metadata_operators) + return _internal_metadata_operators(); +} +inline void ModelProto::unsafe_arena_set_allocated_metadata_operators( + ::mindspore::irpb::OperatorSetProto* metadata_operators) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(metadata_operators_); + } + metadata_operators_ = metadata_operators; + if (metadata_operators) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.ModelProto.metadata_operators) +} +inline ::mindspore::irpb::OperatorSetProto* ModelProto::release_metadata_operators() { + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::OperatorSetProto* temp = metadata_operators_; + metadata_operators_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::OperatorSetProto* ModelProto::unsafe_arena_release_metadata_operators() { + // @@protoc_insertion_point(field_release:mindspore.irpb.ModelProto.metadata_operators) + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::OperatorSetProto* temp = metadata_operators_; + metadata_operators_ = nullptr; + return temp; +} +inline ::mindspore::irpb::OperatorSetProto* ModelProto::_internal_mutable_metadata_operators() { + _has_bits_[0] |= 0x00000004u; + if (metadata_operators_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::OperatorSetProto>(GetArena()); + metadata_operators_ = p; + } + return metadata_operators_; +} +inline ::mindspore::irpb::OperatorSetProto* ModelProto::mutable_metadata_operators() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.ModelProto.metadata_operators) + return _internal_mutable_metadata_operators(); +} +inline void ModelProto::set_allocated_metadata_operators(::mindspore::irpb::OperatorSetProto* metadata_operators) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete metadata_operators_; + } + if (metadata_operators) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(metadata_operators); + if (message_arena != submessage_arena) { + metadata_operators = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, metadata_operators, submessage_arena); + } + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + metadata_operators_ = metadata_operators; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.ModelProto.metadata_operators) +} + +// ------------------------------------------------------------------- + +// OperatorProto + +// optional string name = 1; +inline bool OperatorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool OperatorProto::has_name() const { + return _internal_has_name(); +} +inline void OperatorProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& OperatorProto::name() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.OperatorProto.name) + return _internal_name(); +} +inline void OperatorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.OperatorProto.name) +} +inline std::string* OperatorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.OperatorProto.name) + return _internal_mutable_name(); +} +inline const std::string& OperatorProto::_internal_name() const { + return name_.Get(); +} +inline void OperatorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void OperatorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.OperatorProto.name) +} +inline void OperatorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.OperatorProto.name) +} +inline void OperatorProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.OperatorProto.name) +} +inline std::string* OperatorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* OperatorProto::release_name() { + // @@protoc_insertion_point(field_release:mindspore.irpb.OperatorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void OperatorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.OperatorProto.name) +} + +// optional bytes config = 2; +inline bool OperatorProto::_internal_has_config() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool OperatorProto::has_config() const { + return _internal_has_config(); +} +inline void OperatorProto::clear_config() { + config_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& OperatorProto::config() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.OperatorProto.config) + return _internal_config(); +} +inline void OperatorProto::set_config(const std::string& value) { + _internal_set_config(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.OperatorProto.config) +} +inline std::string* OperatorProto::mutable_config() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.OperatorProto.config) + return _internal_mutable_config(); +} +inline const std::string& OperatorProto::_internal_config() const { + return config_.Get(); +} +inline void OperatorProto::_internal_set_config(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + config_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void OperatorProto::set_config(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + config_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.OperatorProto.config) +} +inline void OperatorProto::set_config(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + config_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.OperatorProto.config) +} +inline void OperatorProto::set_config(const void* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + config_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.OperatorProto.config) +} +inline std::string* OperatorProto::_internal_mutable_config() { + _has_bits_[0] |= 0x00000002u; + return config_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* OperatorProto::release_config() { + // @@protoc_insertion_point(field_release:mindspore.irpb.OperatorProto.config) + if (!_internal_has_config()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return config_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void OperatorProto::set_allocated_config(std::string* config) { + if (config != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + config_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), config, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.OperatorProto.config) +} + +// optional bytes obj_info = 3; +inline bool OperatorProto::_internal_has_obj_info() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool OperatorProto::has_obj_info() const { + return _internal_has_obj_info(); +} +inline void OperatorProto::clear_obj_info() { + obj_info_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& OperatorProto::obj_info() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.OperatorProto.obj_info) + return _internal_obj_info(); +} +inline void OperatorProto::set_obj_info(const std::string& value) { + _internal_set_obj_info(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.OperatorProto.obj_info) +} +inline std::string* OperatorProto::mutable_obj_info() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.OperatorProto.obj_info) + return _internal_mutable_obj_info(); +} +inline const std::string& OperatorProto::_internal_obj_info() const { + return obj_info_.Get(); +} +inline void OperatorProto::_internal_set_obj_info(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + obj_info_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void OperatorProto::set_obj_info(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + obj_info_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.OperatorProto.obj_info) +} +inline void OperatorProto::set_obj_info(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + obj_info_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.OperatorProto.obj_info) +} +inline void OperatorProto::set_obj_info(const void* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + obj_info_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.OperatorProto.obj_info) +} +inline std::string* OperatorProto::_internal_mutable_obj_info() { + _has_bits_[0] |= 0x00000004u; + return obj_info_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* OperatorProto::release_obj_info() { + // @@protoc_insertion_point(field_release:mindspore.irpb.OperatorProto.obj_info) + if (!_internal_has_obj_info()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return obj_info_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void OperatorProto::set_allocated_obj_info(std::string* obj_info) { + if (obj_info != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + obj_info_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), obj_info, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.OperatorProto.obj_info) +} + +// ------------------------------------------------------------------- + +// OperatorSetProto + +// repeated .mindspore.irpb.OperatorProto operators = 1; +inline int OperatorSetProto::_internal_operators_size() const { + return operators_.size(); +} +inline int OperatorSetProto::operators_size() const { + return _internal_operators_size(); +} +inline void OperatorSetProto::clear_operators() { + operators_.Clear(); +} +inline ::mindspore::irpb::OperatorProto* OperatorSetProto::mutable_operators(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.OperatorSetProto.operators) + return operators_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::OperatorProto >* +OperatorSetProto::mutable_operators() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.OperatorSetProto.operators) + return &operators_; +} +inline const ::mindspore::irpb::OperatorProto& OperatorSetProto::_internal_operators(int index) const { + return operators_.Get(index); +} +inline const ::mindspore::irpb::OperatorProto& OperatorSetProto::operators(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.OperatorSetProto.operators) + return _internal_operators(index); +} +inline ::mindspore::irpb::OperatorProto* OperatorSetProto::_internal_add_operators() { + return operators_.Add(); +} +inline ::mindspore::irpb::OperatorProto* OperatorSetProto::add_operators() { + // @@protoc_insertion_point(field_add:mindspore.irpb.OperatorSetProto.operators) + return _internal_add_operators(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::OperatorProto >& +OperatorSetProto::operators() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.OperatorSetProto.operators) + return operators_; +} + +// optional string summary = 2; +inline bool OperatorSetProto::_internal_has_summary() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool OperatorSetProto::has_summary() const { + return _internal_has_summary(); +} +inline void OperatorSetProto::clear_summary() { + summary_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& OperatorSetProto::summary() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.OperatorSetProto.summary) + return _internal_summary(); +} +inline void OperatorSetProto::set_summary(const std::string& value) { + _internal_set_summary(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.OperatorSetProto.summary) +} +inline std::string* OperatorSetProto::mutable_summary() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.OperatorSetProto.summary) + return _internal_mutable_summary(); +} +inline const std::string& OperatorSetProto::_internal_summary() const { + return summary_.Get(); +} +inline void OperatorSetProto::_internal_set_summary(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + summary_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void OperatorSetProto::set_summary(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + summary_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.OperatorSetProto.summary) +} +inline void OperatorSetProto::set_summary(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + summary_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.OperatorSetProto.summary) +} +inline void OperatorSetProto::set_summary(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + summary_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.OperatorSetProto.summary) +} +inline std::string* OperatorSetProto::_internal_mutable_summary() { + _has_bits_[0] |= 0x00000001u; + return summary_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* OperatorSetProto::release_summary() { + // @@protoc_insertion_point(field_release:mindspore.irpb.OperatorSetProto.summary) + if (!_internal_has_summary()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return summary_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void OperatorSetProto::set_allocated_summary(std::string* summary) { + if (summary != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + summary_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), summary, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.OperatorSetProto.summary) +} + +// ------------------------------------------------------------------- + +// GraphProto + +// repeated .mindspore.irpb.NodeProto node = 1; +inline int GraphProto::_internal_node_size() const { + return node_.size(); +} +inline int GraphProto::node_size() const { + return _internal_node_size(); +} +inline void GraphProto::clear_node() { + node_.Clear(); +} +inline ::mindspore::irpb::NodeProto* GraphProto::mutable_node(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.GraphProto.node) + return node_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NodeProto >* +GraphProto::mutable_node() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.GraphProto.node) + return &node_; +} +inline const ::mindspore::irpb::NodeProto& GraphProto::_internal_node(int index) const { + return node_.Get(index); +} +inline const ::mindspore::irpb::NodeProto& GraphProto::node(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.GraphProto.node) + return _internal_node(index); +} +inline ::mindspore::irpb::NodeProto* GraphProto::_internal_add_node() { + return node_.Add(); +} +inline ::mindspore::irpb::NodeProto* GraphProto::add_node() { + // @@protoc_insertion_point(field_add:mindspore.irpb.GraphProto.node) + return _internal_add_node(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NodeProto >& +GraphProto::node() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.GraphProto.node) + return node_; +} + +// optional string name = 2; +inline bool GraphProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool GraphProto::has_name() const { + return _internal_has_name(); +} +inline void GraphProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& GraphProto::name() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.GraphProto.name) + return _internal_name(); +} +inline void GraphProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.GraphProto.name) +} +inline std::string* GraphProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.GraphProto.name) + return _internal_mutable_name(); +} +inline const std::string& GraphProto::_internal_name() const { + return name_.Get(); +} +inline void GraphProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void GraphProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.GraphProto.name) +} +inline void GraphProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.GraphProto.name) +} +inline void GraphProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.GraphProto.name) +} +inline std::string* GraphProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* GraphProto::release_name() { + // @@protoc_insertion_point(field_release:mindspore.irpb.GraphProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void GraphProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.GraphProto.name) +} + +// repeated .mindspore.irpb.ParameterProto parameters = 3; +inline int GraphProto::_internal_parameters_size() const { + return parameters_.size(); +} +inline int GraphProto::parameters_size() const { + return _internal_parameters_size(); +} +inline void GraphProto::clear_parameters() { + parameters_.Clear(); +} +inline ::mindspore::irpb::ParameterProto* GraphProto::mutable_parameters(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.GraphProto.parameters) + return parameters_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::ParameterProto >* +GraphProto::mutable_parameters() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.GraphProto.parameters) + return ¶meters_; +} +inline const ::mindspore::irpb::ParameterProto& GraphProto::_internal_parameters(int index) const { + return parameters_.Get(index); +} +inline const ::mindspore::irpb::ParameterProto& GraphProto::parameters(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.GraphProto.parameters) + return _internal_parameters(index); +} +inline ::mindspore::irpb::ParameterProto* GraphProto::_internal_add_parameters() { + return parameters_.Add(); +} +inline ::mindspore::irpb::ParameterProto* GraphProto::add_parameters() { + // @@protoc_insertion_point(field_add:mindspore.irpb.GraphProto.parameters) + return _internal_add_parameters(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::ParameterProto >& +GraphProto::parameters() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.GraphProto.parameters) + return parameters_; +} + +// repeated .mindspore.irpb.OutputProto outputs = 4; +inline int GraphProto::_internal_outputs_size() const { + return outputs_.size(); +} +inline int GraphProto::outputs_size() const { + return _internal_outputs_size(); +} +inline void GraphProto::clear_outputs() { + outputs_.Clear(); +} +inline ::mindspore::irpb::OutputProto* GraphProto::mutable_outputs(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.GraphProto.outputs) + return outputs_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::OutputProto >* +GraphProto::mutable_outputs() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.GraphProto.outputs) + return &outputs_; +} +inline const ::mindspore::irpb::OutputProto& GraphProto::_internal_outputs(int index) const { + return outputs_.Get(index); +} +inline const ::mindspore::irpb::OutputProto& GraphProto::outputs(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.GraphProto.outputs) + return _internal_outputs(index); +} +inline ::mindspore::irpb::OutputProto* GraphProto::_internal_add_outputs() { + return outputs_.Add(); +} +inline ::mindspore::irpb::OutputProto* GraphProto::add_outputs() { + // @@protoc_insertion_point(field_add:mindspore.irpb.GraphProto.outputs) + return _internal_add_outputs(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::OutputProto >& +GraphProto::outputs() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.GraphProto.outputs) + return outputs_; +} + +// repeated .mindspore.irpb.NamedValueProto const_vals = 5; +inline int GraphProto::_internal_const_vals_size() const { + return const_vals_.size(); +} +inline int GraphProto::const_vals_size() const { + return _internal_const_vals_size(); +} +inline void GraphProto::clear_const_vals() { + const_vals_.Clear(); +} +inline ::mindspore::irpb::NamedValueProto* GraphProto::mutable_const_vals(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.GraphProto.const_vals) + return const_vals_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NamedValueProto >* +GraphProto::mutable_const_vals() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.GraphProto.const_vals) + return &const_vals_; +} +inline const ::mindspore::irpb::NamedValueProto& GraphProto::_internal_const_vals(int index) const { + return const_vals_.Get(index); +} +inline const ::mindspore::irpb::NamedValueProto& GraphProto::const_vals(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.GraphProto.const_vals) + return _internal_const_vals(index); +} +inline ::mindspore::irpb::NamedValueProto* GraphProto::_internal_add_const_vals() { + return const_vals_.Add(); +} +inline ::mindspore::irpb::NamedValueProto* GraphProto::add_const_vals() { + // @@protoc_insertion_point(field_add:mindspore.irpb.GraphProto.const_vals) + return _internal_add_const_vals(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::NamedValueProto >& +GraphProto::const_vals() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.GraphProto.const_vals) + return const_vals_; +} + +// ------------------------------------------------------------------- + +// TensorProto + +// repeated int64 dims = 1; +inline int TensorProto::_internal_dims_size() const { + return dims_.size(); +} +inline int TensorProto::dims_size() const { + return _internal_dims_size(); +} +inline void TensorProto::clear_dims() { + dims_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::_internal_dims(int index) const { + return dims_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::dims(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorProto.dims) + return _internal_dims(index); +} +inline void TensorProto::set_dims(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + dims_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TensorProto.dims) +} +inline void TensorProto::_internal_add_dims(::PROTOBUF_NAMESPACE_ID::int64 value) { + dims_.Add(value); +} +inline void TensorProto::add_dims(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_dims(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.TensorProto.dims) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::_internal_dims() const { + return dims_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::dims() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.TensorProto.dims) + return _internal_dims(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::_internal_mutable_dims() { + return &dims_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::mutable_dims() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.TensorProto.dims) + return _internal_mutable_dims(); +} + +// optional .mindspore.irpb.DataType data_type = 2; +inline bool TensorProto::_internal_has_data_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TensorProto::has_data_type() const { + return _internal_has_data_type(); +} +inline void TensorProto::clear_data_type() { + data_type_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::mindspore::irpb::DataType TensorProto::_internal_data_type() const { + return static_cast< ::mindspore::irpb::DataType >(data_type_); +} +inline ::mindspore::irpb::DataType TensorProto::data_type() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorProto.data_type) + return _internal_data_type(); +} +inline void TensorProto::_internal_set_data_type(::mindspore::irpb::DataType value) { + assert(::mindspore::irpb::DataType_IsValid(value)); + _has_bits_[0] |= 0x00000002u; + data_type_ = value; +} +inline void TensorProto::set_data_type(::mindspore::irpb::DataType value) { + _internal_set_data_type(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TensorProto.data_type) +} + +// repeated float float_data = 3 [packed = true]; +inline int TensorProto::_internal_float_data_size() const { + return float_data_.size(); +} +inline int TensorProto::float_data_size() const { + return _internal_float_data_size(); +} +inline void TensorProto::clear_float_data() { + float_data_.Clear(); +} +inline float TensorProto::_internal_float_data(int index) const { + return float_data_.Get(index); +} +inline float TensorProto::float_data(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorProto.float_data) + return _internal_float_data(index); +} +inline void TensorProto::set_float_data(int index, float value) { + float_data_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TensorProto.float_data) +} +inline void TensorProto::_internal_add_float_data(float value) { + float_data_.Add(value); +} +inline void TensorProto::add_float_data(float value) { + _internal_add_float_data(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.TensorProto.float_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +TensorProto::_internal_float_data() const { + return float_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +TensorProto::float_data() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.TensorProto.float_data) + return _internal_float_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +TensorProto::_internal_mutable_float_data() { + return &float_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +TensorProto::mutable_float_data() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.TensorProto.float_data) + return _internal_mutable_float_data(); +} + +// repeated int32 int32_data = 4 [packed = true]; +inline int TensorProto::_internal_int32_data_size() const { + return int32_data_.size(); +} +inline int TensorProto::int32_data_size() const { + return _internal_int32_data_size(); +} +inline void TensorProto::clear_int32_data() { + int32_data_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::_internal_int32_data(int index) const { + return int32_data_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::int32_data(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorProto.int32_data) + return _internal_int32_data(index); +} +inline void TensorProto::set_int32_data(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + int32_data_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TensorProto.int32_data) +} +inline void TensorProto::_internal_add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value) { + int32_data_.Add(value); +} +inline void TensorProto::add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_int32_data(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.TensorProto.int32_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +TensorProto::_internal_int32_data() const { + return int32_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +TensorProto::int32_data() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.TensorProto.int32_data) + return _internal_int32_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +TensorProto::_internal_mutable_int32_data() { + return &int32_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +TensorProto::mutable_int32_data() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.TensorProto.int32_data) + return _internal_mutable_int32_data(); +} + +// repeated int64 int64_data = 5 [packed = true]; +inline int TensorProto::_internal_int64_data_size() const { + return int64_data_.size(); +} +inline int TensorProto::int64_data_size() const { + return _internal_int64_data_size(); +} +inline void TensorProto::clear_int64_data() { + int64_data_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::_internal_int64_data(int index) const { + return int64_data_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::int64_data(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorProto.int64_data) + return _internal_int64_data(index); +} +inline void TensorProto::set_int64_data(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + int64_data_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TensorProto.int64_data) +} +inline void TensorProto::_internal_add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value) { + int64_data_.Add(value); +} +inline void TensorProto::add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_int64_data(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.TensorProto.int64_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::_internal_int64_data() const { + return int64_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::int64_data() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.TensorProto.int64_data) + return _internal_int64_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::_internal_mutable_int64_data() { + return &int64_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::mutable_int64_data() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.TensorProto.int64_data) + return _internal_mutable_int64_data(); +} + +// repeated double double_data = 6 [packed = true]; +inline int TensorProto::_internal_double_data_size() const { + return double_data_.size(); +} +inline int TensorProto::double_data_size() const { + return _internal_double_data_size(); +} +inline void TensorProto::clear_double_data() { + double_data_.Clear(); +} +inline double TensorProto::_internal_double_data(int index) const { + return double_data_.Get(index); +} +inline double TensorProto::double_data(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorProto.double_data) + return _internal_double_data(index); +} +inline void TensorProto::set_double_data(int index, double value) { + double_data_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TensorProto.double_data) +} +inline void TensorProto::_internal_add_double_data(double value) { + double_data_.Add(value); +} +inline void TensorProto::add_double_data(double value) { + _internal_add_double_data(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.TensorProto.double_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +TensorProto::_internal_double_data() const { + return double_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +TensorProto::double_data() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.TensorProto.double_data) + return _internal_double_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +TensorProto::_internal_mutable_double_data() { + return &double_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +TensorProto::mutable_double_data() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.TensorProto.double_data) + return _internal_mutable_double_data(); +} + +// repeated uint64 uint64_data = 7 [packed = true]; +inline int TensorProto::_internal_uint64_data_size() const { + return uint64_data_.size(); +} +inline int TensorProto::uint64_data_size() const { + return _internal_uint64_data_size(); +} +inline void TensorProto::clear_uint64_data() { + uint64_data_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 TensorProto::_internal_uint64_data(int index) const { + return uint64_data_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 TensorProto::uint64_data(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorProto.uint64_data) + return _internal_uint64_data(index); +} +inline void TensorProto::set_uint64_data(int index, ::PROTOBUF_NAMESPACE_ID::uint64 value) { + uint64_data_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TensorProto.uint64_data) +} +inline void TensorProto::_internal_add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value) { + uint64_data_.Add(value); +} +inline void TensorProto::add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_add_uint64_data(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.TensorProto.uint64_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& +TensorProto::_internal_uint64_data() const { + return uint64_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& +TensorProto::uint64_data() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.TensorProto.uint64_data) + return _internal_uint64_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* +TensorProto::_internal_mutable_uint64_data() { + return &uint64_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* +TensorProto::mutable_uint64_data() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.TensorProto.uint64_data) + return _internal_mutable_uint64_data(); +} + +// optional bytes raw_data = 8; +inline bool TensorProto::_internal_has_raw_data() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TensorProto::has_raw_data() const { + return _internal_has_raw_data(); +} +inline void TensorProto::clear_raw_data() { + raw_data_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& TensorProto::raw_data() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.TensorProto.raw_data) + return _internal_raw_data(); +} +inline void TensorProto::set_raw_data(const std::string& value) { + _internal_set_raw_data(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.TensorProto.raw_data) +} +inline std::string* TensorProto::mutable_raw_data() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.TensorProto.raw_data) + return _internal_mutable_raw_data(); +} +inline const std::string& TensorProto::_internal_raw_data() const { + return raw_data_.Get(); +} +inline void TensorProto::_internal_set_raw_data(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + raw_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TensorProto::set_raw_data(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + raw_data_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.TensorProto.raw_data) +} +inline void TensorProto::set_raw_data(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + raw_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.TensorProto.raw_data) +} +inline void TensorProto::set_raw_data(const void* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + raw_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.TensorProto.raw_data) +} +inline std::string* TensorProto::_internal_mutable_raw_data() { + _has_bits_[0] |= 0x00000001u; + return raw_data_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TensorProto::release_raw_data() { + // @@protoc_insertion_point(field_release:mindspore.irpb.TensorProto.raw_data) + if (!_internal_has_raw_data()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return raw_data_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TensorProto::set_allocated_raw_data(std::string* raw_data) { + if (raw_data != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + raw_data_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), raw_data, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.TensorProto.raw_data) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +} // namespace irpb +} // namespace mindspore + +PROTOBUF_NAMESPACE_OPEN + +template <> struct is_proto_enum< ::mindspore::irpb::InputProto_EdgeType> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::mindspore::irpb::InputProto_EdgeType>() { + return ::mindspore::irpb::InputProto_EdgeType_descriptor(); +} +template <> struct is_proto_enum< ::mindspore::irpb::Version> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::mindspore::irpb::Version>() { + return ::mindspore::irpb::Version_descriptor(); +} +template <> struct is_proto_enum< ::mindspore::irpb::DataType> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::mindspore::irpb::DataType>() { + return ::mindspore::irpb::DataType_descriptor(); +} + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_mindspore_5fanf_5fir_2eproto diff --git a/plugins/mindstudio-insight-plugins/proto/mindspore_anf_ir.proto b/plugins/mindstudio-insight-plugins/proto/mindspore_anf_ir.proto new file mode 100644 index 0000000000000000000000000000000000000000..f221e6d7f75358276382287b6741050bf51cb13f --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/mindspore_anf_ir.proto @@ -0,0 +1,353 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package mindspore.irpb; + + +// Versioning +enum Version { + // unknown version + UNKNOWWN_VERSION = 0; + + // Initial version (IR VERSION 1), published on Sep 23, 2019 + IR_VERSION = 0x0000000000000001; +} + +// Data type definition +enum DataType { + DT_UNDEFINED = 0; + // Basic types. + DT_BOOL = 1; // bool + + DT_INT8 = 2; // int8_t + DT_INT16 = 3; // int16_t + DT_INT32 = 4; // int32_t + DT_INT64 = 5; // int64_t + + DT_UINT8 = 6; // uint8_t + DT_UINT16 = 7; // uint16_t + DT_UINT32 = 8; // uint32_t + DT_UINT64 = 9; // uint64_t + + DT_FLOAT16 = 10; // float 16 + DT_FLOAT32 = 11; // float 32 + DT_FLOAT64 = 12; // float 64 + + DT_STRING = 13; // string + DT_TENSOR = 14; // tensor + DT_GRAPH = 15; // graph + + // list type + DT_BOOLS = 16; // list of bool + + DT_INTS8 = 17; // list of int8_t + DT_INTS16 = 18; // list of int16_t + DT_INTS32 = 19; // list of int32_t + DT_INTS64 = 20; // list of int64_t + + DT_UINTS8 = 21; // list of uint8_t + DT_UINTS16 = 22; // list of uint16_t + DT_UINTS32 = 23; // list of uint32_t + DT_UINTS64 = 24; // list of uint64_t + + DT_FLOATS16 = 25; // list of float16 + DT_FLOATS32 = 26; // list of float32 + DT_FLOATS64 = 27; // list of float64 + + DT_STRINGS = 28; // list of string + DT_TENSORS = 29; // list of tensor + DT_GRAPHS = 30; // list of graph + + DT_TUPLE = 31; // tuple + DT_LIST = 32; // list + DT_DICT = 33; // dictionary + + // other types + DT_NONE = 34; // None + DT_SYM_INST = 35; // Symbolic Key Instance + + // type related type + DT_BASE_INT = 36; // type generic int + DT_BASE_UINT = 37; // type generate unsigned int + DT_BASE_FLOAT = 38; // type generate float + DT_TYPE = 39; // type type + DT_ANY = 40; // type any + DT_REFKEY = 41; // type refkey + DT_REF = 42; // type ref + DT_COMPLEX64 = 43; // list of complex64 + DT_COMPLEX128 = 44; // list of complex128 + DT_BASE_COMPLEX = 45; // type generate complex + + // bfloat type + DT_BFLOAT16 = 46; // bfloat16 + DT_BFLOATS16 = 47; // list of bfloat16 + + // quant type + DT_INT4 = 48; // int4 + + // slice type + DT_SLICE = 49; +} + +// Value definition for attribute value or parameter default value +message ValueProto { + // data type of value + optional DataType dtype = 1; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + optional bool bool_val = 2; // bool + optional int64 int_val = 3; // int + optional uint64 uint_val = 4; // uint + optional float float_val = 5; // float + optional double double_val = 6; // double + optional string str_val = 7; // string + optional TensorProto tensor_val = 8; // tensor value + optional GraphProto graph = 9; // graph + + repeated bool bool_vals = 10; // list of bool + repeated int64 int_vals = 11; // list of int + repeated uint64 uint_vals = 12; // list of uint + repeated float float_vals = 13; // list of float + repeated double double_vals = 14; // list of double + repeated string str_vals = 15; // list of string + repeated TensorProto tensor_vals = 16; // list of tensor value + repeated GraphProto graphs = 17; // list of graph + + // tuple or list + repeated ValueProto values = 18; // tuple, list of value + + // dictionary + repeated NamedValueProto dict_val = 19; // dictionary info + + // filed for type type + optional TypeProto type_val = 20; // type type info +} + +message AttributeProto { + optional string name = 1; // attribute name + optional ValueProto value = 2; // attribute value +} + +message NamedValueProto { + optional string key = 1; // attribute name + optional ValueProto value = 2; // attribute value +} + +// Defines a tensor shape. +message TensorShapeProto { + // One dimension of the tensor. + message Dimension { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). + optional int64 size = 1; + + // Optional name of the tensor dimension. + optional string name = 2; + }; + + repeated Dimension dim = 1; +} + +// Types for graph input(parameter) and output +message TypeProto { + + message Tensor { + // This field MUST have a valid DataType value except DT_TENSOR + optional DataType elem_type = 1; + optional TensorShapeProto shape = 2; // for scalar, this field is not set + } + + // tuple type + message Sequence { + // The type and optional shape of elements of the tuple. + repeated TypeProto elem_types = 1; + }; + + // data type + optional DataType data_type = 1; + + oneof value { + // The type of a tensor. + Tensor tensor_type = 2; + + // The type of a tuple. + Sequence sequence_type = 3; + } +} + +// Defines information on graph parameters, including the name, the type, and +// the default value of parameter if exists. +message ParameterProto { + optional string name = 1; // parameter name + optional TypeProto type = 2; // parameter type + optional ValueProto default_val = 3; // default value of parameter if exists +} + +// Defines graph output information +message OutputProto { + optional string name = 1; // output node name + optional TypeProto type = 2; // output node type +} + +// Define node input information +message InputProto { + enum EdgeType { + DATA_EDGE = 0; // data edge + CONTROL_EDGE = 1; // control edge + } + + optional string name = 1; + optional EdgeType type = 2; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated InputProto input = 1; // namespace Value + optional string name = 2; // namespace Value + + // The symbolic identifier of the Operator to execute. + optional string op_type = 3; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string scope = 4; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // Optional type info of this node + optional TypeProto output_type = 6; + + // other fields for debug + optional uint64 output_i = 7; + + // The full_name_with_scope of CNode + optional string full_name = 8; + + // Note: Id 9 is reserved for the source_address field of the debugger, please see debug_graph.proto + + // As same as the IR file instance name field. + optional string instance_name = 10; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // ir version + optional int64 ir_version = 1; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 2; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 3; + + // The parameterized graph that is evaluated to execute the model. + optional GraphProto graph = 4; + + // metadata info of operators + optional OperatorSetProto metadata_operators = 5; +}; + +message OperatorProto { + optional string name = 1; // used as key, must be distinct + optional bytes config = 2; // operator config info + optional bytes obj_info = 3; // operator related object info, e.g. content of operator binary or name +}; + +message OperatorSetProto { + repeated OperatorProto operators = 1; + optional string summary = 2; // summary info of operators, e.g. file position of operators file +} + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + optional string name = 2; // namespace Graph + + // The parameters(inputs) and outputs of the graph. + repeated ParameterProto parameters = 3; + repeated OutputProto outputs = 4; + + // Constants used in this graph + repeated NamedValueProto const_vals = 5; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid DataType value except DT_TENSOR + optional DataType data_type = 2; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float values + repeated float float_data = 3 [packed = true]; + + // For int32, uint8, int8, uint16, int16, and bool values + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, or BOOL + repeated int32 int32_data = 4 [packed = true]; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 5 [packed = true]; + + // For double + // When this field is present, the data_type field MUST be DOUBLE + repeated double double_data = 6 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 7 [packed = true]; + + // Store raw tensor content. When this raw_data field is used to store tensor value, + // elements MUST be stored in as fixed-width, little-endian order. + optional bytes raw_data = 8; +} diff --git a/plugins/mindstudio-insight-plugins/proto/mindspore_summary.pb.cc b/plugins/mindstudio-insight-plugins/proto/mindspore_summary.pb.cc new file mode 100644 index 0000000000000000000000000000000000000000..2835d2cb0440bb53fc8389908cdd7849ab9cfa94 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/mindspore_summary.pb.cc @@ -0,0 +1,6951 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mindspore_summary.proto + +#include "mindspore_summary.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<4> scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<5> scc_info_Explain_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Explain_Benchmark_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Explain_Explanation_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_Explain_Hoc_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Explain_HocLayer_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Explain_Inference_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Explain_Metadata_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<3> scc_info_LossLandscape_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_LossLandscape_LossPath_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_LossLandscape_Metadata_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_LossLandscape_Point_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_Summary_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_Summary_Histogram_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Summary_Histogram_bucket_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Summary_Image_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<4> scc_info_Summary_Value_mindspore_5fsummary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_mindspore_5fanf_5fir_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorProto_mindspore_5fanf_5fir_2eproto; +namespace mindspore { +namespace irpb { +class EventDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr version_; + const ::mindspore::irpb::GraphProto* graph_def_; + const ::mindspore::irpb::Summary* summary_; + const ::mindspore::irpb::Explain* explain_; +} _Event_default_instance_; +class LossLandscape_PointDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _LossLandscape_Point_default_instance_; +class LossLandscape_LossPathDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _LossLandscape_LossPath_default_instance_; +class LossLandscape_MetadataDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _LossLandscape_Metadata_default_instance_; +class LossLandscapeDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _LossLandscape_default_instance_; +class Summary_ImageDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Summary_Image_default_instance_; +class Summary_Histogram_bucketDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Summary_Histogram_bucket_default_instance_; +class Summary_HistogramDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Summary_Histogram_default_instance_; +class Summary_ValueDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; + float scalar_value_; + const ::mindspore::irpb::Summary_Image* image_; + const ::mindspore::irpb::TensorProto* tensor_; + const ::mindspore::irpb::Summary_Histogram* histogram_; + const ::mindspore::irpb::LossLandscape* loss_landscape_; +} _Summary_Value_default_instance_; +class SummaryDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Summary_default_instance_; +class Explain_InferenceDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Explain_Inference_default_instance_; +class Explain_ExplanationDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Explain_Explanation_default_instance_; +class Explain_BenchmarkDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Explain_Benchmark_default_instance_; +class Explain_MetadataDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Explain_Metadata_default_instance_; +class Explain_HocLayerDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Explain_HocLayer_default_instance_; +class Explain_HocDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Explain_Hoc_default_instance_; +class ExplainDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Explain_default_instance_; +} // namespace irpb +} // namespace mindspore +static void InitDefaultsscc_info_Event_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Event_default_instance_; + new (ptr) ::mindspore::irpb::Event(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Event::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<3> scc_info_Event_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 3, 0, InitDefaultsscc_info_Event_mindspore_5fsummary_2eproto}, { + &scc_info_AttributeProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_Summary_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_mindspore_5fsummary_2eproto.base,}}; + +static void InitDefaultsscc_info_Explain_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Explain_default_instance_; + new (ptr) ::mindspore::irpb::Explain(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Explain::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<5> scc_info_Explain_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 5, 0, InitDefaultsscc_info_Explain_mindspore_5fsummary_2eproto}, { + &scc_info_Explain_Inference_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_Explanation_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_Benchmark_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_Metadata_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_Hoc_mindspore_5fsummary_2eproto.base,}}; + +static void InitDefaultsscc_info_Explain_Benchmark_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Explain_Benchmark_default_instance_; + new (ptr) ::mindspore::irpb::Explain_Benchmark(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Explain_Benchmark::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Explain_Benchmark_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_Explain_Benchmark_mindspore_5fsummary_2eproto}, {}}; + +static void InitDefaultsscc_info_Explain_Explanation_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Explain_Explanation_default_instance_; + new (ptr) ::mindspore::irpb::Explain_Explanation(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Explain_Explanation::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Explain_Explanation_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_Explain_Explanation_mindspore_5fsummary_2eproto}, {}}; + +static void InitDefaultsscc_info_Explain_Hoc_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Explain_Hoc_default_instance_; + new (ptr) ::mindspore::irpb::Explain_Hoc(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Explain_Hoc::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_Explain_Hoc_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_Explain_Hoc_mindspore_5fsummary_2eproto}, { + &scc_info_Explain_HocLayer_mindspore_5fsummary_2eproto.base,}}; + +static void InitDefaultsscc_info_Explain_HocLayer_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Explain_HocLayer_default_instance_; + new (ptr) ::mindspore::irpb::Explain_HocLayer(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Explain_HocLayer::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Explain_HocLayer_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_Explain_HocLayer_mindspore_5fsummary_2eproto}, {}}; + +static void InitDefaultsscc_info_Explain_Inference_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Explain_Inference_default_instance_; + new (ptr) ::mindspore::irpb::Explain_Inference(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Explain_Inference::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Explain_Inference_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_Explain_Inference_mindspore_5fsummary_2eproto}, {}}; + +static void InitDefaultsscc_info_Explain_Metadata_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Explain_Metadata_default_instance_; + new (ptr) ::mindspore::irpb::Explain_Metadata(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Explain_Metadata::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Explain_Metadata_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_Explain_Metadata_mindspore_5fsummary_2eproto}, {}}; + +static void InitDefaultsscc_info_LossLandscape_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_LossLandscape_default_instance_; + new (ptr) ::mindspore::irpb::LossLandscape(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::LossLandscape::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<3> scc_info_LossLandscape_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 3, 0, InitDefaultsscc_info_LossLandscape_mindspore_5fsummary_2eproto}, { + &scc_info_LossLandscape_Point_mindspore_5fsummary_2eproto.base, + &scc_info_LossLandscape_LossPath_mindspore_5fsummary_2eproto.base, + &scc_info_LossLandscape_Metadata_mindspore_5fsummary_2eproto.base,}}; + +static void InitDefaultsscc_info_LossLandscape_LossPath_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_LossLandscape_LossPath_default_instance_; + new (ptr) ::mindspore::irpb::LossLandscape_LossPath(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::LossLandscape_LossPath::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_LossLandscape_LossPath_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_LossLandscape_LossPath_mindspore_5fsummary_2eproto}, { + &scc_info_LossLandscape_Point_mindspore_5fsummary_2eproto.base,}}; + +static void InitDefaultsscc_info_LossLandscape_Metadata_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_LossLandscape_Metadata_default_instance_; + new (ptr) ::mindspore::irpb::LossLandscape_Metadata(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::LossLandscape_Metadata::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_LossLandscape_Metadata_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_LossLandscape_Metadata_mindspore_5fsummary_2eproto}, {}}; + +static void InitDefaultsscc_info_LossLandscape_Point_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_LossLandscape_Point_default_instance_; + new (ptr) ::mindspore::irpb::LossLandscape_Point(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::LossLandscape_Point::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_LossLandscape_Point_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_LossLandscape_Point_mindspore_5fsummary_2eproto}, { + &scc_info_TensorProto_mindspore_5fanf_5fir_2eproto.base,}}; + +static void InitDefaultsscc_info_Summary_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Summary_default_instance_; + new (ptr) ::mindspore::irpb::Summary(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Summary::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_Summary_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_Summary_mindspore_5fsummary_2eproto}, { + &scc_info_Summary_Value_mindspore_5fsummary_2eproto.base,}}; + +static void InitDefaultsscc_info_Summary_Histogram_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Summary_Histogram_default_instance_; + new (ptr) ::mindspore::irpb::Summary_Histogram(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Summary_Histogram::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_Summary_Histogram_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_Summary_Histogram_mindspore_5fsummary_2eproto}, { + &scc_info_Summary_Histogram_bucket_mindspore_5fsummary_2eproto.base,}}; + +static void InitDefaultsscc_info_Summary_Histogram_bucket_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Summary_Histogram_bucket_default_instance_; + new (ptr) ::mindspore::irpb::Summary_Histogram_bucket(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Summary_Histogram_bucket::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Summary_Histogram_bucket_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_Summary_Histogram_bucket_mindspore_5fsummary_2eproto}, {}}; + +static void InitDefaultsscc_info_Summary_Image_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Summary_Image_default_instance_; + new (ptr) ::mindspore::irpb::Summary_Image(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Summary_Image::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Summary_Image_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_Summary_Image_mindspore_5fsummary_2eproto}, {}}; + +static void InitDefaultsscc_info_Summary_Value_mindspore_5fsummary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::mindspore::irpb::_Summary_Value_default_instance_; + new (ptr) ::mindspore::irpb::Summary_Value(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::mindspore::irpb::Summary_Value::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<4> scc_info_Summary_Value_mindspore_5fsummary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 4, 0, InitDefaultsscc_info_Summary_Value_mindspore_5fsummary_2eproto}, { + &scc_info_Summary_Image_mindspore_5fsummary_2eproto.base, + &scc_info_TensorProto_mindspore_5fanf_5fir_2eproto.base, + &scc_info_Summary_Histogram_mindspore_5fsummary_2eproto.base, + &scc_info_LossLandscape_mindspore_5fsummary_2eproto.base,}}; + +static ::PROTOBUF_NAMESPACE_ID::Metadata file_level_metadata_mindspore_5fsummary_2eproto[17]; +static constexpr ::PROTOBUF_NAMESPACE_ID::EnumDescriptor const** file_level_enum_descriptors_mindspore_5fsummary_2eproto = nullptr; +static constexpr ::PROTOBUF_NAMESPACE_ID::ServiceDescriptor const** file_level_service_descriptors_mindspore_5fsummary_2eproto = nullptr; + +const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_mindspore_5fsummary_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Event, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Event, _internal_metadata_), + ~0u, // no _extensions_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Event, _oneof_case_[0]), + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Event, wall_time_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Event, step_), + offsetof(::mindspore::irpb::EventDefaultTypeInternal, version_), + offsetof(::mindspore::irpb::EventDefaultTypeInternal, graph_def_), + offsetof(::mindspore::irpb::EventDefaultTypeInternal, summary_), + offsetof(::mindspore::irpb::EventDefaultTypeInternal, explain_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Event, what_), + 0, + 1, + ~0u, + ~0u, + ~0u, + ~0u, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_Point, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_Point, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_Point, x_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_Point, y_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_Point, z_), + 0, + 1, + 2, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_LossPath, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_LossPath, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_LossPath, intervals_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_LossPath, points_), + ~0u, + 0, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_Metadata, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_Metadata, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_Metadata, decomposition_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_Metadata, unit_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape_Metadata, step_per_epoch_), + 0, + 1, + 2, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape, landscape_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape, loss_path_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape, metadata_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::LossLandscape, convergence_point_), + 0, + 1, + 2, + 3, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Image, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Image, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Image, height_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Image, width_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Image, colorspace_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Image, encoded_image_), + 1, + 2, + 3, + 0, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram_bucket, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram_bucket, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram_bucket, left_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram_bucket, width_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram_bucket, count_), + 0, + 1, + 2, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram, buckets_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram, nan_count_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram, pos_inf_count_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram, neg_inf_count_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram, max_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram, min_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram, sum_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Histogram, count_), + ~0u, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Value, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Value, _internal_metadata_), + ~0u, // no _extensions_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Value, _oneof_case_[0]), + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Value, tag_), + offsetof(::mindspore::irpb::Summary_ValueDefaultTypeInternal, scalar_value_), + offsetof(::mindspore::irpb::Summary_ValueDefaultTypeInternal, image_), + offsetof(::mindspore::irpb::Summary_ValueDefaultTypeInternal, tensor_), + offsetof(::mindspore::irpb::Summary_ValueDefaultTypeInternal, histogram_), + offsetof(::mindspore::irpb::Summary_ValueDefaultTypeInternal, loss_landscape_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary_Value, value_), + 0, + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Summary, value_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Inference, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Inference, ground_truth_prob_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Inference, predicted_label_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Inference, predicted_prob_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Inference, ground_truth_prob_sd_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Inference, ground_truth_prob_itl95_low_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Inference, ground_truth_prob_itl95_hi_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Inference, predicted_prob_sd_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Inference, predicted_prob_itl95_low_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Inference, predicted_prob_itl95_hi_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Explanation, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Explanation, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Explanation, explain_method_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Explanation, label_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Explanation, heatmap_path_), + 0, + 2, + 1, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Benchmark, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Benchmark, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Benchmark, benchmark_method_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Benchmark, explain_method_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Benchmark, total_score_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Benchmark, label_score_), + 0, + 1, + 2, + ~0u, + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Metadata, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Metadata, label_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Metadata, explain_method_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Metadata, benchmark_method_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_HocLayer, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_HocLayer, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_HocLayer, prob_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_HocLayer, box_), + 0, + ~0u, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Hoc, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Hoc, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Hoc, label_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Hoc, mask_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain_Hoc, layer_), + 1, + 0, + ~0u, + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, _has_bits_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, sample_id_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, image_path_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, ground_truth_label_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, inference_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, explanation_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, benchmark_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, metadata_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, status_), + PROTOBUF_FIELD_OFFSET(::mindspore::irpb::Explain, hoc_), + 4, + 0, + ~0u, + 2, + ~0u, + ~0u, + 3, + 1, + ~0u, +}; +static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, 12, sizeof(::mindspore::irpb::Event)}, + { 18, 26, sizeof(::mindspore::irpb::LossLandscape_Point)}, + { 29, 36, sizeof(::mindspore::irpb::LossLandscape_LossPath)}, + { 38, 46, sizeof(::mindspore::irpb::LossLandscape_Metadata)}, + { 49, 58, sizeof(::mindspore::irpb::LossLandscape)}, + { 62, 71, sizeof(::mindspore::irpb::Summary_Image)}, + { 75, 83, sizeof(::mindspore::irpb::Summary_Histogram_bucket)}, + { 86, 99, sizeof(::mindspore::irpb::Summary_Histogram)}, + { 107, 119, sizeof(::mindspore::irpb::Summary_Value)}, + { 125, -1, sizeof(::mindspore::irpb::Summary)}, + { 131, -1, sizeof(::mindspore::irpb::Explain_Inference)}, + { 145, 153, sizeof(::mindspore::irpb::Explain_Explanation)}, + { 156, 165, sizeof(::mindspore::irpb::Explain_Benchmark)}, + { 169, -1, sizeof(::mindspore::irpb::Explain_Metadata)}, + { 177, 184, sizeof(::mindspore::irpb::Explain_HocLayer)}, + { 186, 194, sizeof(::mindspore::irpb::Explain_Hoc)}, + { 197, 211, sizeof(::mindspore::irpb::Explain)}, +}; + +static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = { + reinterpret_cast(&::mindspore::irpb::_Event_default_instance_), + reinterpret_cast(&::mindspore::irpb::_LossLandscape_Point_default_instance_), + reinterpret_cast(&::mindspore::irpb::_LossLandscape_LossPath_default_instance_), + reinterpret_cast(&::mindspore::irpb::_LossLandscape_Metadata_default_instance_), + reinterpret_cast(&::mindspore::irpb::_LossLandscape_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Summary_Image_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Summary_Histogram_bucket_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Summary_Histogram_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Summary_Value_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Summary_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Explain_Inference_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Explain_Explanation_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Explain_Benchmark_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Explain_Metadata_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Explain_HocLayer_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Explain_Hoc_default_instance_), + reinterpret_cast(&::mindspore::irpb::_Explain_default_instance_), +}; + +const char descriptor_table_protodef_mindspore_5fsummary_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = + "\n\027mindspore_summary.proto\022\016mindspore.irp" + "b\032\026mindspore_anf_ir.proto\"\314\001\n\005Event\022\021\n\tw" + "all_time\030\001 \002(\001\022\014\n\004step\030\002 \001(\003\022\021\n\007version\030" + "\003 \001(\tH\000\022/\n\tgraph_def\030\004 \001(\0132\032.mindspore.i" + "rpb.GraphProtoH\000\022*\n\007summary\030\005 \001(\0132\027.mind" + "spore.irpb.SummaryH\000\022*\n\007explain\030\006 \001(\0132\027." + "mindspore.irpb.ExplainH\000B\006\n\004what\"\232\004\n\rLos" + "sLandscape\0226\n\tlandscape\030\001 \001(\0132#.mindspor" + "e.irpb.LossLandscape.Point\0229\n\tloss_path\030" + "\002 \001(\0132&.mindspore.irpb.LossLandscape.Los" + "sPath\0228\n\010metadata\030\003 \001(\0132&.mindspore.irpb" + ".LossLandscape.Metadata\022>\n\021convergence_p" + "oint\030\004 \001(\0132#.mindspore.irpb.LossLandscap" + "e.Point\032\177\n\005Point\022&\n\001x\030\001 \001(\0132\033.mindspore." + "irpb.TensorProto\022&\n\001y\030\002 \001(\0132\033.mindspore." + "irpb.TensorProto\022&\n\001z\030\003 \001(\0132\033.mindspore." + "irpb.TensorProto\032R\n\010LossPath\022\021\n\tinterval" + "s\030\001 \003(\005\0223\n\006points\030\002 \001(\0132#.mindspore.irpb" + ".LossLandscape.Point\032G\n\010Metadata\022\025\n\rdeco" + "mposition\030\001 \001(\t\022\014\n\004unit\030\002 \001(\t\022\026\n\016step_pe" + "r_epoch\030\003 \001(\005\"\210\005\n\007Summary\022,\n\005value\030\001 \003(\013" + "2\035.mindspore.irpb.Summary.Value\032Q\n\005Image" + "\022\016\n\006height\030\001 \002(\005\022\r\n\005width\030\002 \002(\005\022\022\n\ncolor" + "space\030\003 \002(\005\022\025\n\rencoded_image\030\004 \002(\014\032\363\001\n\tH" + "istogram\0229\n\007buckets\030\001 \003(\0132(.mindspore.ir" + "pb.Summary.Histogram.bucket\022\021\n\tnan_count" + "\030\002 \001(\003\022\025\n\rpos_inf_count\030\003 \001(\003\022\025\n\rneg_inf" + "_count\030\004 \001(\003\022\013\n\003max\030\005 \001(\001\022\013\n\003min\030\006 \001(\001\022\013" + "\n\003sum\030\007 \001(\001\022\r\n\005count\030\010 \001(\003\0324\n\006bucket\022\014\n\004" + "left\030\001 \002(\001\022\r\n\005width\030\002 \002(\001\022\r\n\005count\030\003 \002(\003" + "\032\205\002\n\005Value\022\013\n\003tag\030\001 \002(\t\022\026\n\014scalar_value\030" + "\003 \001(\002H\000\022.\n\005image\030\004 \001(\0132\035.mindspore.irpb." + "Summary.ImageH\000\022-\n\006tensor\030\010 \001(\0132\033.mindsp" + "ore.irpb.TensorProtoH\000\0226\n\thistogram\030\t \001(" + "\0132!.mindspore.irpb.Summary.HistogramH\000\0227" + "\n\016loss_landscape\030\n \001(\0132\035.mindspore.irpb." + "LossLandscapeH\000B\007\n\005value\"\375\007\n\007Explain\022\021\n\t" + "sample_id\030\001 \001(\005\022\022\n\nimage_path\030\002 \001(\t\022\032\n\022g" + "round_truth_label\030\003 \003(\005\0224\n\tinference\030\004 \001" + "(\0132!.mindspore.irpb.Explain.Inference\0228\n" + "\013explanation\030\005 \003(\0132#.mindspore.irpb.Expl" + "ain.Explanation\0224\n\tbenchmark\030\006 \003(\0132!.min" + "dspore.irpb.Explain.Benchmark\0222\n\010metadat" + "a\030\007 \001(\0132 .mindspore.irpb.Explain.Metadat" + "a\022\016\n\006status\030\010 \001(\t\022(\n\003hoc\030\t \003(\0132\033.mindspo" + "re.irpb.Explain.Hoc\032\234\002\n\tInference\022\031\n\021gro" + "und_truth_prob\030\001 \003(\002\022\027\n\017predicted_label\030" + "\002 \003(\005\022\026\n\016predicted_prob\030\003 \003(\002\022\034\n\024ground_" + "truth_prob_sd\030\004 \003(\002\022#\n\033ground_truth_prob" + "_itl95_low\030\005 \003(\002\022\"\n\032ground_truth_prob_it" + "l95_hi\030\006 \003(\002\022\031\n\021predicted_prob_sd\030\007 \003(\002\022" + " \n\030predicted_prob_itl95_low\030\010 \003(\002\022\037\n\027pre" + "dicted_prob_itl95_hi\030\t \003(\002\032J\n\013Explanatio" + "n\022\026\n\016explain_method\030\001 \001(\t\022\r\n\005label\030\002 \001(\005" + "\022\024\n\014heatmap_path\030\003 \001(\t\032g\n\tBenchmark\022\030\n\020b" + "enchmark_method\030\001 \001(\t\022\026\n\016explain_method\030" + "\002 \001(\t\022\023\n\013total_score\030\003 \001(\002\022\023\n\013label_scor" + "e\030\004 \003(\002\032K\n\010Metadata\022\r\n\005label\030\001 \003(\t\022\026\n\016ex" + "plain_method\030\002 \003(\t\022\030\n\020benchmark_method\030\003" + " \003(\t\032%\n\010HocLayer\022\014\n\004prob\030\001 \001(\002\022\013\n\003box\030\002 " + "\003(\005\032S\n\003Hoc\022\r\n\005label\030\001 \001(\005\022\014\n\004mask\030\002 \001(\t\022" + "/\n\005layer\030\003 \003(\0132 .mindspore.irpb.Explain." + "HocLayerB\003\370\001\001" + ; +static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_mindspore_5fsummary_2eproto_deps[1] = { + &::descriptor_table_mindspore_5fanf_5fir_2eproto, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_mindspore_5fsummary_2eproto_sccs[17] = { + &scc_info_Event_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_Benchmark_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_Explanation_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_Hoc_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_HocLayer_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_Inference_mindspore_5fsummary_2eproto.base, + &scc_info_Explain_Metadata_mindspore_5fsummary_2eproto.base, + &scc_info_LossLandscape_mindspore_5fsummary_2eproto.base, + &scc_info_LossLandscape_LossPath_mindspore_5fsummary_2eproto.base, + &scc_info_LossLandscape_Metadata_mindspore_5fsummary_2eproto.base, + &scc_info_LossLandscape_Point_mindspore_5fsummary_2eproto.base, + &scc_info_Summary_mindspore_5fsummary_2eproto.base, + &scc_info_Summary_Histogram_mindspore_5fsummary_2eproto.base, + &scc_info_Summary_Histogram_bucket_mindspore_5fsummary_2eproto.base, + &scc_info_Summary_Image_mindspore_5fsummary_2eproto.base, + &scc_info_Summary_Value_mindspore_5fsummary_2eproto.base, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_mindspore_5fsummary_2eproto_once; +const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_mindspore_5fsummary_2eproto = { + false, false, descriptor_table_protodef_mindspore_5fsummary_2eproto, "mindspore_summary.proto", 2493, + &descriptor_table_mindspore_5fsummary_2eproto_once, descriptor_table_mindspore_5fsummary_2eproto_sccs, descriptor_table_mindspore_5fsummary_2eproto_deps, 17, 1, + schemas, file_default_instances, TableStruct_mindspore_5fsummary_2eproto::offsets, + file_level_metadata_mindspore_5fsummary_2eproto, 17, file_level_enum_descriptors_mindspore_5fsummary_2eproto, file_level_service_descriptors_mindspore_5fsummary_2eproto, +}; + +// Force running AddDescriptors() at dynamic initialization time. +static bool dynamic_init_dummy_mindspore_5fsummary_2eproto = (static_cast(::PROTOBUF_NAMESPACE_ID::internal::AddDescriptors(&descriptor_table_mindspore_5fsummary_2eproto)), true); +namespace mindspore { +namespace irpb { + +// =================================================================== + +void Event::InitAsDefaultInstance() { + ::mindspore::irpb::_Event_default_instance_.version_.UnsafeSetDefault( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::mindspore::irpb::_Event_default_instance_.graph_def_ = const_cast< ::mindspore::irpb::GraphProto*>( + ::mindspore::irpb::GraphProto::internal_default_instance()); + ::mindspore::irpb::_Event_default_instance_.summary_ = const_cast< ::mindspore::irpb::Summary*>( + ::mindspore::irpb::Summary::internal_default_instance()); + ::mindspore::irpb::_Event_default_instance_.explain_ = const_cast< ::mindspore::irpb::Explain*>( + ::mindspore::irpb::Explain::internal_default_instance()); +} +class Event::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_wall_time(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_step(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static const ::mindspore::irpb::GraphProto& graph_def(const Event* msg); + static const ::mindspore::irpb::Summary& summary(const Event* msg); + static const ::mindspore::irpb::Explain& explain(const Event* msg); + static bool MissingRequiredFields(const HasBits& has_bits) { + return ((has_bits[0] & 0x00000001) ^ 0x00000001) != 0; + } +}; + +const ::mindspore::irpb::GraphProto& +Event::_Internal::graph_def(const Event* msg) { + return *msg->what_.graph_def_; +} +const ::mindspore::irpb::Summary& +Event::_Internal::summary(const Event* msg) { + return *msg->what_.summary_; +} +const ::mindspore::irpb::Explain& +Event::_Internal::explain(const Event* msg) { + return *msg->what_.explain_; +} +void Event::set_allocated_graph_def(::mindspore::irpb::GraphProto* graph_def) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_what(); + if (graph_def) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(graph_def)->GetArena(); + if (message_arena != submessage_arena) { + graph_def = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, graph_def, submessage_arena); + } + set_has_graph_def(); + what_.graph_def_ = graph_def; + } + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Event.graph_def) +} +void Event::clear_graph_def() { + if (_internal_has_graph_def()) { + if (GetArena() == nullptr) { + delete what_.graph_def_; + } + clear_has_what(); + } +} +void Event::set_allocated_summary(::mindspore::irpb::Summary* summary) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_what(); + if (summary) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(summary); + if (message_arena != submessage_arena) { + summary = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, summary, submessage_arena); + } + set_has_summary(); + what_.summary_ = summary; + } + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Event.summary) +} +void Event::set_allocated_explain(::mindspore::irpb::Explain* explain) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_what(); + if (explain) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(explain); + if (message_arena != submessage_arena) { + explain = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, explain, submessage_arena); + } + set_has_explain(); + what_.explain_ = explain; + } + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Event.explain) +} +Event::Event(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Event) +} +Event::Event(const Event& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::memcpy(&wall_time_, &from.wall_time_, + static_cast(reinterpret_cast(&step_) - + reinterpret_cast(&wall_time_)) + sizeof(step_)); + clear_has_what(); + switch (from.what_case()) { + case kVersion: { + _internal_set_version(from._internal_version()); + break; + } + case kGraphDef: { + _internal_mutable_graph_def()->::mindspore::irpb::GraphProto::MergeFrom(from._internal_graph_def()); + break; + } + case kSummary: { + _internal_mutable_summary()->::mindspore::irpb::Summary::MergeFrom(from._internal_summary()); + break; + } + case kExplain: { + _internal_mutable_explain()->::mindspore::irpb::Explain::MergeFrom(from._internal_explain()); + break; + } + case WHAT_NOT_SET: { + break; + } + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Event) +} + +void Event::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Event_mindspore_5fsummary_2eproto.base); + ::memset(&wall_time_, 0, static_cast( + reinterpret_cast(&step_) - + reinterpret_cast(&wall_time_)) + sizeof(step_)); + clear_has_what(); +} + +Event::~Event() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Event) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Event::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (has_what()) { + clear_what(); + } +} + +void Event::ArenaDtor(void* object) { + Event* _this = reinterpret_cast< Event* >(object); + (void)_this; +} +void Event::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Event::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Event& Event::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Event_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Event::clear_what() { +// @@protoc_insertion_point(one_of_clear_start:mindspore.irpb.Event) + switch (what_case()) { + case kVersion: { + what_.version_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + break; + } + case kGraphDef: { + if (GetArena() == nullptr) { + delete what_.graph_def_; + } + break; + } + case kSummary: { + if (GetArena() == nullptr) { + delete what_.summary_; + } + break; + } + case kExplain: { + if (GetArena() == nullptr) { + delete what_.explain_; + } + break; + } + case WHAT_NOT_SET: { + break; + } + } + _oneof_case_[0] = WHAT_NOT_SET; +} + + +void Event::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Event) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + ::memset(&wall_time_, 0, static_cast( + reinterpret_cast(&step_) - + reinterpret_cast(&wall_time_)) + sizeof(step_)); + } + clear_what(); + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Event::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // required double wall_time = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 9)) { + _Internal::set_has_wall_time(&has_bits); + wall_time_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // optional int64 step = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + _Internal::set_has_step(&has_bits); + step_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string version = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_version(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Event.version"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // .mindspore.irpb.GraphProto graph_def = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + ptr = ctx->ParseMessage(_internal_mutable_graph_def(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .mindspore.irpb.Summary summary = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr = ctx->ParseMessage(_internal_mutable_summary(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .mindspore.irpb.Explain explain = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr = ctx->ParseMessage(_internal_mutable_explain(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Event::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Event) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // required double wall_time = 1; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(1, this->_internal_wall_time(), target); + } + + // optional int64 step = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(2, this->_internal_step(), target); + } + + switch (what_case()) { + case kVersion: { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_version().data(), static_cast(this->_internal_version().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Event.version"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_version(), target); + break; + } + case kGraphDef: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 4, _Internal::graph_def(this), target, stream); + break; + } + case kSummary: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 5, _Internal::summary(this), target, stream); + break; + } + case kExplain: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 6, _Internal::explain(this), target, stream); + break; + } + default: ; + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Event) + return target; +} + +size_t Event::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Event) + size_t total_size = 0; + + // required double wall_time = 1; + if (_internal_has_wall_time()) { + total_size += 1 + 8; + } + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // optional int64 step = 2; + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_step()); + } + + switch (what_case()) { + // string version = 3; + case kVersion: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_version()); + break; + } + // .mindspore.irpb.GraphProto graph_def = 4; + case kGraphDef: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *what_.graph_def_); + break; + } + // .mindspore.irpb.Summary summary = 5; + case kSummary: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *what_.summary_); + break; + } + // .mindspore.irpb.Explain explain = 6; + case kExplain: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *what_.explain_); + break; + } + case WHAT_NOT_SET: { + break; + } + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Event::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Event) + GOOGLE_DCHECK_NE(&from, this); + const Event* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Event) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Event) + MergeFrom(*source); + } +} + +void Event::MergeFrom(const Event& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Event) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + wall_time_ = from.wall_time_; + } + if (cached_has_bits & 0x00000002u) { + step_ = from.step_; + } + _has_bits_[0] |= cached_has_bits; + } + switch (from.what_case()) { + case kVersion: { + _internal_set_version(from._internal_version()); + break; + } + case kGraphDef: { + _internal_mutable_graph_def()->::mindspore::irpb::GraphProto::MergeFrom(from._internal_graph_def()); + break; + } + case kSummary: { + _internal_mutable_summary()->::mindspore::irpb::Summary::MergeFrom(from._internal_summary()); + break; + } + case kExplain: { + _internal_mutable_explain()->::mindspore::irpb::Explain::MergeFrom(from._internal_explain()); + break; + } + case WHAT_NOT_SET: { + break; + } + } +} + +void Event::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Event) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Event::CopyFrom(const Event& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Event) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Event::IsInitialized() const { + if (_Internal::MissingRequiredFields(_has_bits_)) return false; + switch (what_case()) { + case kVersion: { + break; + } + case kGraphDef: { + break; + } + case kSummary: { + if (has_summary()) { + if (!this->summary().IsInitialized()) return false; + } + break; + } + case kExplain: { + break; + } + case WHAT_NOT_SET: { + break; + } + } + return true; +} + +void Event::InternalSwap(Event* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(Event, step_) + + sizeof(Event::step_) + - PROTOBUF_FIELD_OFFSET(Event, wall_time_)>( + reinterpret_cast(&wall_time_), + reinterpret_cast(&other->wall_time_)); + swap(what_, other->what_); + swap(_oneof_case_[0], other->_oneof_case_[0]); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Event::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void LossLandscape_Point::InitAsDefaultInstance() { + ::mindspore::irpb::_LossLandscape_Point_default_instance_._instance.get_mutable()->x_ = const_cast< ::mindspore::irpb::TensorProto*>( + ::mindspore::irpb::TensorProto::internal_default_instance()); + ::mindspore::irpb::_LossLandscape_Point_default_instance_._instance.get_mutable()->y_ = const_cast< ::mindspore::irpb::TensorProto*>( + ::mindspore::irpb::TensorProto::internal_default_instance()); + ::mindspore::irpb::_LossLandscape_Point_default_instance_._instance.get_mutable()->z_ = const_cast< ::mindspore::irpb::TensorProto*>( + ::mindspore::irpb::TensorProto::internal_default_instance()); +} +class LossLandscape_Point::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static const ::mindspore::irpb::TensorProto& x(const LossLandscape_Point* msg); + static void set_has_x(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::mindspore::irpb::TensorProto& y(const LossLandscape_Point* msg); + static void set_has_y(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static const ::mindspore::irpb::TensorProto& z(const LossLandscape_Point* msg); + static void set_has_z(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } +}; + +const ::mindspore::irpb::TensorProto& +LossLandscape_Point::_Internal::x(const LossLandscape_Point* msg) { + return *msg->x_; +} +const ::mindspore::irpb::TensorProto& +LossLandscape_Point::_Internal::y(const LossLandscape_Point* msg) { + return *msg->y_; +} +const ::mindspore::irpb::TensorProto& +LossLandscape_Point::_Internal::z(const LossLandscape_Point* msg) { + return *msg->z_; +} +void LossLandscape_Point::clear_x() { + if (x_ != nullptr) x_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +void LossLandscape_Point::clear_y() { + if (y_ != nullptr) y_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +void LossLandscape_Point::clear_z() { + if (z_ != nullptr) z_->Clear(); + _has_bits_[0] &= ~0x00000004u; +} +LossLandscape_Point::LossLandscape_Point(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.LossLandscape.Point) +} +LossLandscape_Point::LossLandscape_Point(const LossLandscape_Point& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + if (from._internal_has_x()) { + x_ = new ::mindspore::irpb::TensorProto(*from.x_); + } else { + x_ = nullptr; + } + if (from._internal_has_y()) { + y_ = new ::mindspore::irpb::TensorProto(*from.y_); + } else { + y_ = nullptr; + } + if (from._internal_has_z()) { + z_ = new ::mindspore::irpb::TensorProto(*from.z_); + } else { + z_ = nullptr; + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.LossLandscape.Point) +} + +void LossLandscape_Point::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_LossLandscape_Point_mindspore_5fsummary_2eproto.base); + ::memset(&x_, 0, static_cast( + reinterpret_cast(&z_) - + reinterpret_cast(&x_)) + sizeof(z_)); +} + +LossLandscape_Point::~LossLandscape_Point() { + // @@protoc_insertion_point(destructor:mindspore.irpb.LossLandscape.Point) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void LossLandscape_Point::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (this != internal_default_instance()) delete x_; + if (this != internal_default_instance()) delete y_; + if (this != internal_default_instance()) delete z_; +} + +void LossLandscape_Point::ArenaDtor(void* object) { + LossLandscape_Point* _this = reinterpret_cast< LossLandscape_Point* >(object); + (void)_this; +} +void LossLandscape_Point::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void LossLandscape_Point::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const LossLandscape_Point& LossLandscape_Point::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_LossLandscape_Point_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void LossLandscape_Point::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.LossLandscape.Point) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + GOOGLE_DCHECK(x_ != nullptr); + x_->Clear(); + } + if (cached_has_bits & 0x00000002u) { + GOOGLE_DCHECK(y_ != nullptr); + y_->Clear(); + } + if (cached_has_bits & 0x00000004u) { + GOOGLE_DCHECK(z_ != nullptr); + z_->Clear(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* LossLandscape_Point::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional .mindspore.irpb.TensorProto x = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr = ctx->ParseMessage(_internal_mutable_x(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.TensorProto y = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_y(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.TensorProto z = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr = ctx->ParseMessage(_internal_mutable_z(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* LossLandscape_Point::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.LossLandscape.Point) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional .mindspore.irpb.TensorProto x = 1; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 1, _Internal::x(this), target, stream); + } + + // optional .mindspore.irpb.TensorProto y = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::y(this), target, stream); + } + + // optional .mindspore.irpb.TensorProto z = 3; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 3, _Internal::z(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.LossLandscape.Point) + return target; +} + +size_t LossLandscape_Point::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.LossLandscape.Point) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + // optional .mindspore.irpb.TensorProto x = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *x_); + } + + // optional .mindspore.irpb.TensorProto y = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *y_); + } + + // optional .mindspore.irpb.TensorProto z = 3; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *z_); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void LossLandscape_Point::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.LossLandscape.Point) + GOOGLE_DCHECK_NE(&from, this); + const LossLandscape_Point* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.LossLandscape.Point) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.LossLandscape.Point) + MergeFrom(*source); + } +} + +void LossLandscape_Point::MergeFrom(const LossLandscape_Point& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.LossLandscape.Point) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + _internal_mutable_x()->::mindspore::irpb::TensorProto::MergeFrom(from._internal_x()); + } + if (cached_has_bits & 0x00000002u) { + _internal_mutable_y()->::mindspore::irpb::TensorProto::MergeFrom(from._internal_y()); + } + if (cached_has_bits & 0x00000004u) { + _internal_mutable_z()->::mindspore::irpb::TensorProto::MergeFrom(from._internal_z()); + } + } +} + +void LossLandscape_Point::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.LossLandscape.Point) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void LossLandscape_Point::CopyFrom(const LossLandscape_Point& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.LossLandscape.Point) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool LossLandscape_Point::IsInitialized() const { + return true; +} + +void LossLandscape_Point::InternalSwap(LossLandscape_Point* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(LossLandscape_Point, z_) + + sizeof(LossLandscape_Point::z_) + - PROTOBUF_FIELD_OFFSET(LossLandscape_Point, x_)>( + reinterpret_cast(&x_), + reinterpret_cast(&other->x_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata LossLandscape_Point::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void LossLandscape_LossPath::InitAsDefaultInstance() { + ::mindspore::irpb::_LossLandscape_LossPath_default_instance_._instance.get_mutable()->points_ = const_cast< ::mindspore::irpb::LossLandscape_Point*>( + ::mindspore::irpb::LossLandscape_Point::internal_default_instance()); +} +class LossLandscape_LossPath::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static const ::mindspore::irpb::LossLandscape_Point& points(const LossLandscape_LossPath* msg); + static void set_has_points(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +const ::mindspore::irpb::LossLandscape_Point& +LossLandscape_LossPath::_Internal::points(const LossLandscape_LossPath* msg) { + return *msg->points_; +} +LossLandscape_LossPath::LossLandscape_LossPath(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + intervals_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.LossLandscape.LossPath) +} +LossLandscape_LossPath::LossLandscape_LossPath(const LossLandscape_LossPath& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + intervals_(from.intervals_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + if (from._internal_has_points()) { + points_ = new ::mindspore::irpb::LossLandscape_Point(*from.points_); + } else { + points_ = nullptr; + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.LossLandscape.LossPath) +} + +void LossLandscape_LossPath::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_LossLandscape_LossPath_mindspore_5fsummary_2eproto.base); + points_ = nullptr; +} + +LossLandscape_LossPath::~LossLandscape_LossPath() { + // @@protoc_insertion_point(destructor:mindspore.irpb.LossLandscape.LossPath) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void LossLandscape_LossPath::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (this != internal_default_instance()) delete points_; +} + +void LossLandscape_LossPath::ArenaDtor(void* object) { + LossLandscape_LossPath* _this = reinterpret_cast< LossLandscape_LossPath* >(object); + (void)_this; +} +void LossLandscape_LossPath::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void LossLandscape_LossPath::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const LossLandscape_LossPath& LossLandscape_LossPath::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_LossLandscape_LossPath_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void LossLandscape_LossPath::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.LossLandscape.LossPath) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + intervals_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + GOOGLE_DCHECK(points_ != nullptr); + points_->Clear(); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* LossLandscape_LossPath::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated int32 intervals = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_intervals(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<8>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt32Parser(_internal_mutable_intervals(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.LossLandscape.Point points = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_points(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* LossLandscape_LossPath::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.LossLandscape.LossPath) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated int32 intervals = 1; + for (int i = 0, n = this->_internal_intervals_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(1, this->_internal_intervals(i), target); + } + + cached_has_bits = _has_bits_[0]; + // optional .mindspore.irpb.LossLandscape.Point points = 2; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::points(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.LossLandscape.LossPath) + return target; +} + +size_t LossLandscape_LossPath::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.LossLandscape.LossPath) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated int32 intervals = 1; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int32Size(this->intervals_); + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_intervals_size()); + total_size += data_size; + } + + // optional .mindspore.irpb.LossLandscape.Point points = 2; + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *points_); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void LossLandscape_LossPath::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.LossLandscape.LossPath) + GOOGLE_DCHECK_NE(&from, this); + const LossLandscape_LossPath* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.LossLandscape.LossPath) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.LossLandscape.LossPath) + MergeFrom(*source); + } +} + +void LossLandscape_LossPath::MergeFrom(const LossLandscape_LossPath& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.LossLandscape.LossPath) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + intervals_.MergeFrom(from.intervals_); + if (from._internal_has_points()) { + _internal_mutable_points()->::mindspore::irpb::LossLandscape_Point::MergeFrom(from._internal_points()); + } +} + +void LossLandscape_LossPath::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.LossLandscape.LossPath) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void LossLandscape_LossPath::CopyFrom(const LossLandscape_LossPath& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.LossLandscape.LossPath) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool LossLandscape_LossPath::IsInitialized() const { + return true; +} + +void LossLandscape_LossPath::InternalSwap(LossLandscape_LossPath* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + intervals_.InternalSwap(&other->intervals_); + swap(points_, other->points_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata LossLandscape_LossPath::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void LossLandscape_Metadata::InitAsDefaultInstance() { +} +class LossLandscape_Metadata::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_decomposition(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_unit(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_step_per_epoch(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } +}; + +LossLandscape_Metadata::LossLandscape_Metadata(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.LossLandscape.Metadata) +} +LossLandscape_Metadata::LossLandscape_Metadata(const LossLandscape_Metadata& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + decomposition_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_decomposition()) { + decomposition_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_decomposition(), + GetArena()); + } + unit_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_unit()) { + unit_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_unit(), + GetArena()); + } + step_per_epoch_ = from.step_per_epoch_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.LossLandscape.Metadata) +} + +void LossLandscape_Metadata::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_LossLandscape_Metadata_mindspore_5fsummary_2eproto.base); + decomposition_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + unit_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + step_per_epoch_ = 0; +} + +LossLandscape_Metadata::~LossLandscape_Metadata() { + // @@protoc_insertion_point(destructor:mindspore.irpb.LossLandscape.Metadata) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void LossLandscape_Metadata::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + decomposition_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + unit_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void LossLandscape_Metadata::ArenaDtor(void* object) { + LossLandscape_Metadata* _this = reinterpret_cast< LossLandscape_Metadata* >(object); + (void)_this; +} +void LossLandscape_Metadata::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void LossLandscape_Metadata::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const LossLandscape_Metadata& LossLandscape_Metadata::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_LossLandscape_Metadata_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void LossLandscape_Metadata::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.LossLandscape.Metadata) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + decomposition_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + unit_.ClearNonDefaultToEmpty(); + } + } + step_per_epoch_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* LossLandscape_Metadata::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string decomposition = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_decomposition(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.LossLandscape.Metadata.decomposition"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string unit = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_unit(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.LossLandscape.Metadata.unit"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional int32 step_per_epoch = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + _Internal::set_has_step_per_epoch(&has_bits); + step_per_epoch_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* LossLandscape_Metadata::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.LossLandscape.Metadata) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string decomposition = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_decomposition().data(), static_cast(this->_internal_decomposition().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.LossLandscape.Metadata.decomposition"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_decomposition(), target); + } + + // optional string unit = 2; + if (cached_has_bits & 0x00000002u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_unit().data(), static_cast(this->_internal_unit().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.LossLandscape.Metadata.unit"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_unit(), target); + } + + // optional int32 step_per_epoch = 3; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(3, this->_internal_step_per_epoch(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.LossLandscape.Metadata) + return target; +} + +size_t LossLandscape_Metadata::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.LossLandscape.Metadata) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + // optional string decomposition = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_decomposition()); + } + + // optional string unit = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_unit()); + } + + // optional int32 step_per_epoch = 3; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_step_per_epoch()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void LossLandscape_Metadata::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.LossLandscape.Metadata) + GOOGLE_DCHECK_NE(&from, this); + const LossLandscape_Metadata* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.LossLandscape.Metadata) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.LossLandscape.Metadata) + MergeFrom(*source); + } +} + +void LossLandscape_Metadata::MergeFrom(const LossLandscape_Metadata& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.LossLandscape.Metadata) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_decomposition(from._internal_decomposition()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_unit(from._internal_unit()); + } + if (cached_has_bits & 0x00000004u) { + step_per_epoch_ = from.step_per_epoch_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void LossLandscape_Metadata::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.LossLandscape.Metadata) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void LossLandscape_Metadata::CopyFrom(const LossLandscape_Metadata& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.LossLandscape.Metadata) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool LossLandscape_Metadata::IsInitialized() const { + return true; +} + +void LossLandscape_Metadata::InternalSwap(LossLandscape_Metadata* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + decomposition_.Swap(&other->decomposition_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + unit_.Swap(&other->unit_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(step_per_epoch_, other->step_per_epoch_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata LossLandscape_Metadata::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void LossLandscape::InitAsDefaultInstance() { + ::mindspore::irpb::_LossLandscape_default_instance_._instance.get_mutable()->landscape_ = const_cast< ::mindspore::irpb::LossLandscape_Point*>( + ::mindspore::irpb::LossLandscape_Point::internal_default_instance()); + ::mindspore::irpb::_LossLandscape_default_instance_._instance.get_mutable()->loss_path_ = const_cast< ::mindspore::irpb::LossLandscape_LossPath*>( + ::mindspore::irpb::LossLandscape_LossPath::internal_default_instance()); + ::mindspore::irpb::_LossLandscape_default_instance_._instance.get_mutable()->metadata_ = const_cast< ::mindspore::irpb::LossLandscape_Metadata*>( + ::mindspore::irpb::LossLandscape_Metadata::internal_default_instance()); + ::mindspore::irpb::_LossLandscape_default_instance_._instance.get_mutable()->convergence_point_ = const_cast< ::mindspore::irpb::LossLandscape_Point*>( + ::mindspore::irpb::LossLandscape_Point::internal_default_instance()); +} +class LossLandscape::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static const ::mindspore::irpb::LossLandscape_Point& landscape(const LossLandscape* msg); + static void set_has_landscape(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::mindspore::irpb::LossLandscape_LossPath& loss_path(const LossLandscape* msg); + static void set_has_loss_path(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static const ::mindspore::irpb::LossLandscape_Metadata& metadata(const LossLandscape* msg); + static void set_has_metadata(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static const ::mindspore::irpb::LossLandscape_Point& convergence_point(const LossLandscape* msg); + static void set_has_convergence_point(HasBits* has_bits) { + (*has_bits)[0] |= 8u; + } +}; + +const ::mindspore::irpb::LossLandscape_Point& +LossLandscape::_Internal::landscape(const LossLandscape* msg) { + return *msg->landscape_; +} +const ::mindspore::irpb::LossLandscape_LossPath& +LossLandscape::_Internal::loss_path(const LossLandscape* msg) { + return *msg->loss_path_; +} +const ::mindspore::irpb::LossLandscape_Metadata& +LossLandscape::_Internal::metadata(const LossLandscape* msg) { + return *msg->metadata_; +} +const ::mindspore::irpb::LossLandscape_Point& +LossLandscape::_Internal::convergence_point(const LossLandscape* msg) { + return *msg->convergence_point_; +} +LossLandscape::LossLandscape(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.LossLandscape) +} +LossLandscape::LossLandscape(const LossLandscape& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + if (from._internal_has_landscape()) { + landscape_ = new ::mindspore::irpb::LossLandscape_Point(*from.landscape_); + } else { + landscape_ = nullptr; + } + if (from._internal_has_loss_path()) { + loss_path_ = new ::mindspore::irpb::LossLandscape_LossPath(*from.loss_path_); + } else { + loss_path_ = nullptr; + } + if (from._internal_has_metadata()) { + metadata_ = new ::mindspore::irpb::LossLandscape_Metadata(*from.metadata_); + } else { + metadata_ = nullptr; + } + if (from._internal_has_convergence_point()) { + convergence_point_ = new ::mindspore::irpb::LossLandscape_Point(*from.convergence_point_); + } else { + convergence_point_ = nullptr; + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.LossLandscape) +} + +void LossLandscape::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_LossLandscape_mindspore_5fsummary_2eproto.base); + ::memset(&landscape_, 0, static_cast( + reinterpret_cast(&convergence_point_) - + reinterpret_cast(&landscape_)) + sizeof(convergence_point_)); +} + +LossLandscape::~LossLandscape() { + // @@protoc_insertion_point(destructor:mindspore.irpb.LossLandscape) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void LossLandscape::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (this != internal_default_instance()) delete landscape_; + if (this != internal_default_instance()) delete loss_path_; + if (this != internal_default_instance()) delete metadata_; + if (this != internal_default_instance()) delete convergence_point_; +} + +void LossLandscape::ArenaDtor(void* object) { + LossLandscape* _this = reinterpret_cast< LossLandscape* >(object); + (void)_this; +} +void LossLandscape::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void LossLandscape::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const LossLandscape& LossLandscape::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_LossLandscape_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void LossLandscape::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.LossLandscape) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000000fu) { + if (cached_has_bits & 0x00000001u) { + GOOGLE_DCHECK(landscape_ != nullptr); + landscape_->Clear(); + } + if (cached_has_bits & 0x00000002u) { + GOOGLE_DCHECK(loss_path_ != nullptr); + loss_path_->Clear(); + } + if (cached_has_bits & 0x00000004u) { + GOOGLE_DCHECK(metadata_ != nullptr); + metadata_->Clear(); + } + if (cached_has_bits & 0x00000008u) { + GOOGLE_DCHECK(convergence_point_ != nullptr); + convergence_point_->Clear(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* LossLandscape::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional .mindspore.irpb.LossLandscape.Point landscape = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr = ctx->ParseMessage(_internal_mutable_landscape(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.LossLandscape.LossPath loss_path = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_loss_path(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.LossLandscape.Metadata metadata = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr = ctx->ParseMessage(_internal_mutable_metadata(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.LossLandscape.Point convergence_point = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + ptr = ctx->ParseMessage(_internal_mutable_convergence_point(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* LossLandscape::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.LossLandscape) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional .mindspore.irpb.LossLandscape.Point landscape = 1; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 1, _Internal::landscape(this), target, stream); + } + + // optional .mindspore.irpb.LossLandscape.LossPath loss_path = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::loss_path(this), target, stream); + } + + // optional .mindspore.irpb.LossLandscape.Metadata metadata = 3; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 3, _Internal::metadata(this), target, stream); + } + + // optional .mindspore.irpb.LossLandscape.Point convergence_point = 4; + if (cached_has_bits & 0x00000008u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 4, _Internal::convergence_point(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.LossLandscape) + return target; +} + +size_t LossLandscape::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.LossLandscape) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000000fu) { + // optional .mindspore.irpb.LossLandscape.Point landscape = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *landscape_); + } + + // optional .mindspore.irpb.LossLandscape.LossPath loss_path = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *loss_path_); + } + + // optional .mindspore.irpb.LossLandscape.Metadata metadata = 3; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *metadata_); + } + + // optional .mindspore.irpb.LossLandscape.Point convergence_point = 4; + if (cached_has_bits & 0x00000008u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *convergence_point_); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void LossLandscape::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.LossLandscape) + GOOGLE_DCHECK_NE(&from, this); + const LossLandscape* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.LossLandscape) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.LossLandscape) + MergeFrom(*source); + } +} + +void LossLandscape::MergeFrom(const LossLandscape& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.LossLandscape) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x0000000fu) { + if (cached_has_bits & 0x00000001u) { + _internal_mutable_landscape()->::mindspore::irpb::LossLandscape_Point::MergeFrom(from._internal_landscape()); + } + if (cached_has_bits & 0x00000002u) { + _internal_mutable_loss_path()->::mindspore::irpb::LossLandscape_LossPath::MergeFrom(from._internal_loss_path()); + } + if (cached_has_bits & 0x00000004u) { + _internal_mutable_metadata()->::mindspore::irpb::LossLandscape_Metadata::MergeFrom(from._internal_metadata()); + } + if (cached_has_bits & 0x00000008u) { + _internal_mutable_convergence_point()->::mindspore::irpb::LossLandscape_Point::MergeFrom(from._internal_convergence_point()); + } + } +} + +void LossLandscape::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.LossLandscape) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void LossLandscape::CopyFrom(const LossLandscape& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.LossLandscape) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool LossLandscape::IsInitialized() const { + return true; +} + +void LossLandscape::InternalSwap(LossLandscape* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(LossLandscape, convergence_point_) + + sizeof(LossLandscape::convergence_point_) + - PROTOBUF_FIELD_OFFSET(LossLandscape, landscape_)>( + reinterpret_cast(&landscape_), + reinterpret_cast(&other->landscape_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata LossLandscape::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Summary_Image::InitAsDefaultInstance() { +} +class Summary_Image::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_height(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_width(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static void set_has_colorspace(HasBits* has_bits) { + (*has_bits)[0] |= 8u; + } + static void set_has_encoded_image(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static bool MissingRequiredFields(const HasBits& has_bits) { + return ((has_bits[0] & 0x0000000f) ^ 0x0000000f) != 0; + } +}; + +Summary_Image::Summary_Image(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Summary.Image) +} +Summary_Image::Summary_Image(const Summary_Image& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + encoded_image_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_encoded_image()) { + encoded_image_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_encoded_image(), + GetArena()); + } + ::memcpy(&height_, &from.height_, + static_cast(reinterpret_cast(&colorspace_) - + reinterpret_cast(&height_)) + sizeof(colorspace_)); + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Summary.Image) +} + +void Summary_Image::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Summary_Image_mindspore_5fsummary_2eproto.base); + encoded_image_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&height_, 0, static_cast( + reinterpret_cast(&colorspace_) - + reinterpret_cast(&height_)) + sizeof(colorspace_)); +} + +Summary_Image::~Summary_Image() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Summary.Image) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Summary_Image::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + encoded_image_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void Summary_Image::ArenaDtor(void* object) { + Summary_Image* _this = reinterpret_cast< Summary_Image* >(object); + (void)_this; +} +void Summary_Image::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Summary_Image::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Summary_Image& Summary_Image::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Summary_Image_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Summary_Image::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Summary.Image) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + encoded_image_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x0000000eu) { + ::memset(&height_, 0, static_cast( + reinterpret_cast(&colorspace_) - + reinterpret_cast(&height_)) + sizeof(colorspace_)); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Summary_Image::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // required int32 height = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + _Internal::set_has_height(&has_bits); + height_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // required int32 width = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + _Internal::set_has_width(&has_bits); + width_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // required int32 colorspace = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + _Internal::set_has_colorspace(&has_bits); + colorspace_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // required bytes encoded_image = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + auto str = _internal_mutable_encoded_image(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Summary_Image::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Summary.Image) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // required int32 height = 1; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(1, this->_internal_height(), target); + } + + // required int32 width = 2; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(2, this->_internal_width(), target); + } + + // required int32 colorspace = 3; + if (cached_has_bits & 0x00000008u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(3, this->_internal_colorspace(), target); + } + + // required bytes encoded_image = 4; + if (cached_has_bits & 0x00000001u) { + target = stream->WriteBytesMaybeAliased( + 4, this->_internal_encoded_image(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Summary.Image) + return target; +} + +size_t Summary_Image::RequiredFieldsByteSizeFallback() const { +// @@protoc_insertion_point(required_fields_byte_size_fallback_start:mindspore.irpb.Summary.Image) + size_t total_size = 0; + + if (_internal_has_encoded_image()) { + // required bytes encoded_image = 4; + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_encoded_image()); + } + + if (_internal_has_height()) { + // required int32 height = 1; + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_height()); + } + + if (_internal_has_width()) { + // required int32 width = 2; + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_width()); + } + + if (_internal_has_colorspace()) { + // required int32 colorspace = 3; + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_colorspace()); + } + + return total_size; +} +size_t Summary_Image::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Summary.Image) + size_t total_size = 0; + + if (((_has_bits_[0] & 0x0000000f) ^ 0x0000000f) == 0) { // All required fields are present. + // required bytes encoded_image = 4; + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_encoded_image()); + + // required int32 height = 1; + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_height()); + + // required int32 width = 2; + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_width()); + + // required int32 colorspace = 3; + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_colorspace()); + + } else { + total_size += RequiredFieldsByteSizeFallback(); + } + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Summary_Image::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Summary.Image) + GOOGLE_DCHECK_NE(&from, this); + const Summary_Image* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Summary.Image) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Summary.Image) + MergeFrom(*source); + } +} + +void Summary_Image::MergeFrom(const Summary_Image& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Summary.Image) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x0000000fu) { + if (cached_has_bits & 0x00000001u) { + _internal_set_encoded_image(from._internal_encoded_image()); + } + if (cached_has_bits & 0x00000002u) { + height_ = from.height_; + } + if (cached_has_bits & 0x00000004u) { + width_ = from.width_; + } + if (cached_has_bits & 0x00000008u) { + colorspace_ = from.colorspace_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void Summary_Image::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Summary.Image) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Summary_Image::CopyFrom(const Summary_Image& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Summary.Image) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Summary_Image::IsInitialized() const { + if (_Internal::MissingRequiredFields(_has_bits_)) return false; + return true; +} + +void Summary_Image::InternalSwap(Summary_Image* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + encoded_image_.Swap(&other->encoded_image_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(Summary_Image, colorspace_) + + sizeof(Summary_Image::colorspace_) + - PROTOBUF_FIELD_OFFSET(Summary_Image, height_)>( + reinterpret_cast(&height_), + reinterpret_cast(&other->height_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Summary_Image::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Summary_Histogram_bucket::InitAsDefaultInstance() { +} +class Summary_Histogram_bucket::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_left(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_width(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_count(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static bool MissingRequiredFields(const HasBits& has_bits) { + return ((has_bits[0] & 0x00000007) ^ 0x00000007) != 0; + } +}; + +Summary_Histogram_bucket::Summary_Histogram_bucket(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Summary.Histogram.bucket) +} +Summary_Histogram_bucket::Summary_Histogram_bucket(const Summary_Histogram_bucket& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::memcpy(&left_, &from.left_, + static_cast(reinterpret_cast(&count_) - + reinterpret_cast(&left_)) + sizeof(count_)); + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Summary.Histogram.bucket) +} + +void Summary_Histogram_bucket::SharedCtor() { + ::memset(&left_, 0, static_cast( + reinterpret_cast(&count_) - + reinterpret_cast(&left_)) + sizeof(count_)); +} + +Summary_Histogram_bucket::~Summary_Histogram_bucket() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Summary.Histogram.bucket) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Summary_Histogram_bucket::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void Summary_Histogram_bucket::ArenaDtor(void* object) { + Summary_Histogram_bucket* _this = reinterpret_cast< Summary_Histogram_bucket* >(object); + (void)_this; +} +void Summary_Histogram_bucket::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Summary_Histogram_bucket::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Summary_Histogram_bucket& Summary_Histogram_bucket::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Summary_Histogram_bucket_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Summary_Histogram_bucket::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Summary.Histogram.bucket) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + ::memset(&left_, 0, static_cast( + reinterpret_cast(&count_) - + reinterpret_cast(&left_)) + sizeof(count_)); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Summary_Histogram_bucket::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // required double left = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 9)) { + _Internal::set_has_left(&has_bits); + left_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // required double width = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 17)) { + _Internal::set_has_width(&has_bits); + width_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // required int64 count = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + _Internal::set_has_count(&has_bits); + count_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Summary_Histogram_bucket::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Summary.Histogram.bucket) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // required double left = 1; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(1, this->_internal_left(), target); + } + + // required double width = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(2, this->_internal_width(), target); + } + + // required int64 count = 3; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(3, this->_internal_count(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Summary.Histogram.bucket) + return target; +} + +size_t Summary_Histogram_bucket::RequiredFieldsByteSizeFallback() const { +// @@protoc_insertion_point(required_fields_byte_size_fallback_start:mindspore.irpb.Summary.Histogram.bucket) + size_t total_size = 0; + + if (_internal_has_left()) { + // required double left = 1; + total_size += 1 + 8; + } + + if (_internal_has_width()) { + // required double width = 2; + total_size += 1 + 8; + } + + if (_internal_has_count()) { + // required int64 count = 3; + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_count()); + } + + return total_size; +} +size_t Summary_Histogram_bucket::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Summary.Histogram.bucket) + size_t total_size = 0; + + if (((_has_bits_[0] & 0x00000007) ^ 0x00000007) == 0) { // All required fields are present. + // required double left = 1; + total_size += 1 + 8; + + // required double width = 2; + total_size += 1 + 8; + + // required int64 count = 3; + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_count()); + + } else { + total_size += RequiredFieldsByteSizeFallback(); + } + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Summary_Histogram_bucket::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Summary.Histogram.bucket) + GOOGLE_DCHECK_NE(&from, this); + const Summary_Histogram_bucket* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Summary.Histogram.bucket) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Summary.Histogram.bucket) + MergeFrom(*source); + } +} + +void Summary_Histogram_bucket::MergeFrom(const Summary_Histogram_bucket& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Summary.Histogram.bucket) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + left_ = from.left_; + } + if (cached_has_bits & 0x00000002u) { + width_ = from.width_; + } + if (cached_has_bits & 0x00000004u) { + count_ = from.count_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void Summary_Histogram_bucket::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Summary.Histogram.bucket) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Summary_Histogram_bucket::CopyFrom(const Summary_Histogram_bucket& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Summary.Histogram.bucket) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Summary_Histogram_bucket::IsInitialized() const { + if (_Internal::MissingRequiredFields(_has_bits_)) return false; + return true; +} + +void Summary_Histogram_bucket::InternalSwap(Summary_Histogram_bucket* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(Summary_Histogram_bucket, count_) + + sizeof(Summary_Histogram_bucket::count_) + - PROTOBUF_FIELD_OFFSET(Summary_Histogram_bucket, left_)>( + reinterpret_cast(&left_), + reinterpret_cast(&other->left_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Summary_Histogram_bucket::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Summary_Histogram::InitAsDefaultInstance() { +} +class Summary_Histogram::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_nan_count(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_pos_inf_count(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_neg_inf_count(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static void set_has_max(HasBits* has_bits) { + (*has_bits)[0] |= 8u; + } + static void set_has_min(HasBits* has_bits) { + (*has_bits)[0] |= 16u; + } + static void set_has_sum(HasBits* has_bits) { + (*has_bits)[0] |= 32u; + } + static void set_has_count(HasBits* has_bits) { + (*has_bits)[0] |= 64u; + } +}; + +Summary_Histogram::Summary_Histogram(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + buckets_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Summary.Histogram) +} +Summary_Histogram::Summary_Histogram(const Summary_Histogram& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + buckets_(from.buckets_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::memcpy(&nan_count_, &from.nan_count_, + static_cast(reinterpret_cast(&count_) - + reinterpret_cast(&nan_count_)) + sizeof(count_)); + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Summary.Histogram) +} + +void Summary_Histogram::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Summary_Histogram_mindspore_5fsummary_2eproto.base); + ::memset(&nan_count_, 0, static_cast( + reinterpret_cast(&count_) - + reinterpret_cast(&nan_count_)) + sizeof(count_)); +} + +Summary_Histogram::~Summary_Histogram() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Summary.Histogram) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Summary_Histogram::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void Summary_Histogram::ArenaDtor(void* object) { + Summary_Histogram* _this = reinterpret_cast< Summary_Histogram* >(object); + (void)_this; +} +void Summary_Histogram::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Summary_Histogram::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Summary_Histogram& Summary_Histogram::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Summary_Histogram_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Summary_Histogram::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Summary.Histogram) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + buckets_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000007fu) { + ::memset(&nan_count_, 0, static_cast( + reinterpret_cast(&count_) - + reinterpret_cast(&nan_count_)) + sizeof(count_)); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Summary_Histogram::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .mindspore.irpb.Summary.Histogram.bucket buckets = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_buckets(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + // optional int64 nan_count = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + _Internal::set_has_nan_count(&has_bits); + nan_count_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional int64 pos_inf_count = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + _Internal::set_has_pos_inf_count(&has_bits); + pos_inf_count_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional int64 neg_inf_count = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 32)) { + _Internal::set_has_neg_inf_count(&has_bits); + neg_inf_count_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional double max = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 41)) { + _Internal::set_has_max(&has_bits); + max_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // optional double min = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 49)) { + _Internal::set_has_min(&has_bits); + min_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // optional double sum = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 57)) { + _Internal::set_has_sum(&has_bits); + sum_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // optional int64 count = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 64)) { + _Internal::set_has_count(&has_bits); + count_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Summary_Histogram::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Summary.Histogram) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .mindspore.irpb.Summary.Histogram.bucket buckets = 1; + for (unsigned int i = 0, + n = static_cast(this->_internal_buckets_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(1, this->_internal_buckets(i), target, stream); + } + + cached_has_bits = _has_bits_[0]; + // optional int64 nan_count = 2; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(2, this->_internal_nan_count(), target); + } + + // optional int64 pos_inf_count = 3; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(3, this->_internal_pos_inf_count(), target); + } + + // optional int64 neg_inf_count = 4; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(4, this->_internal_neg_inf_count(), target); + } + + // optional double max = 5; + if (cached_has_bits & 0x00000008u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(5, this->_internal_max(), target); + } + + // optional double min = 6; + if (cached_has_bits & 0x00000010u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(6, this->_internal_min(), target); + } + + // optional double sum = 7; + if (cached_has_bits & 0x00000020u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteDoubleToArray(7, this->_internal_sum(), target); + } + + // optional int64 count = 8; + if (cached_has_bits & 0x00000040u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(8, this->_internal_count(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Summary.Histogram) + return target; +} + +size_t Summary_Histogram::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Summary.Histogram) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .mindspore.irpb.Summary.Histogram.bucket buckets = 1; + total_size += 1UL * this->_internal_buckets_size(); + for (const auto& msg : this->buckets_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000007fu) { + // optional int64 nan_count = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_nan_count()); + } + + // optional int64 pos_inf_count = 3; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_pos_inf_count()); + } + + // optional int64 neg_inf_count = 4; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_neg_inf_count()); + } + + // optional double max = 5; + if (cached_has_bits & 0x00000008u) { + total_size += 1 + 8; + } + + // optional double min = 6; + if (cached_has_bits & 0x00000010u) { + total_size += 1 + 8; + } + + // optional double sum = 7; + if (cached_has_bits & 0x00000020u) { + total_size += 1 + 8; + } + + // optional int64 count = 8; + if (cached_has_bits & 0x00000040u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_count()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Summary_Histogram::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Summary.Histogram) + GOOGLE_DCHECK_NE(&from, this); + const Summary_Histogram* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Summary.Histogram) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Summary.Histogram) + MergeFrom(*source); + } +} + +void Summary_Histogram::MergeFrom(const Summary_Histogram& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Summary.Histogram) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + buckets_.MergeFrom(from.buckets_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x0000007fu) { + if (cached_has_bits & 0x00000001u) { + nan_count_ = from.nan_count_; + } + if (cached_has_bits & 0x00000002u) { + pos_inf_count_ = from.pos_inf_count_; + } + if (cached_has_bits & 0x00000004u) { + neg_inf_count_ = from.neg_inf_count_; + } + if (cached_has_bits & 0x00000008u) { + max_ = from.max_; + } + if (cached_has_bits & 0x00000010u) { + min_ = from.min_; + } + if (cached_has_bits & 0x00000020u) { + sum_ = from.sum_; + } + if (cached_has_bits & 0x00000040u) { + count_ = from.count_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void Summary_Histogram::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Summary.Histogram) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Summary_Histogram::CopyFrom(const Summary_Histogram& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Summary.Histogram) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Summary_Histogram::IsInitialized() const { + if (!::PROTOBUF_NAMESPACE_ID::internal::AllAreInitialized(buckets_)) return false; + return true; +} + +void Summary_Histogram::InternalSwap(Summary_Histogram* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + buckets_.InternalSwap(&other->buckets_); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(Summary_Histogram, count_) + + sizeof(Summary_Histogram::count_) + - PROTOBUF_FIELD_OFFSET(Summary_Histogram, nan_count_)>( + reinterpret_cast(&nan_count_), + reinterpret_cast(&other->nan_count_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Summary_Histogram::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Summary_Value::InitAsDefaultInstance() { + ::mindspore::irpb::_Summary_Value_default_instance_.scalar_value_ = 0; + ::mindspore::irpb::_Summary_Value_default_instance_.image_ = const_cast< ::mindspore::irpb::Summary_Image*>( + ::mindspore::irpb::Summary_Image::internal_default_instance()); + ::mindspore::irpb::_Summary_Value_default_instance_.tensor_ = const_cast< ::mindspore::irpb::TensorProto*>( + ::mindspore::irpb::TensorProto::internal_default_instance()); + ::mindspore::irpb::_Summary_Value_default_instance_.histogram_ = const_cast< ::mindspore::irpb::Summary_Histogram*>( + ::mindspore::irpb::Summary_Histogram::internal_default_instance()); + ::mindspore::irpb::_Summary_Value_default_instance_.loss_landscape_ = const_cast< ::mindspore::irpb::LossLandscape*>( + ::mindspore::irpb::LossLandscape::internal_default_instance()); +} +class Summary_Value::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_tag(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::mindspore::irpb::Summary_Image& image(const Summary_Value* msg); + static const ::mindspore::irpb::TensorProto& tensor(const Summary_Value* msg); + static const ::mindspore::irpb::Summary_Histogram& histogram(const Summary_Value* msg); + static const ::mindspore::irpb::LossLandscape& loss_landscape(const Summary_Value* msg); + static bool MissingRequiredFields(const HasBits& has_bits) { + return ((has_bits[0] & 0x00000001) ^ 0x00000001) != 0; + } +}; + +const ::mindspore::irpb::Summary_Image& +Summary_Value::_Internal::image(const Summary_Value* msg) { + return *msg->value_.image_; +} +const ::mindspore::irpb::TensorProto& +Summary_Value::_Internal::tensor(const Summary_Value* msg) { + return *msg->value_.tensor_; +} +const ::mindspore::irpb::Summary_Histogram& +Summary_Value::_Internal::histogram(const Summary_Value* msg) { + return *msg->value_.histogram_; +} +const ::mindspore::irpb::LossLandscape& +Summary_Value::_Internal::loss_landscape(const Summary_Value* msg) { + return *msg->value_.loss_landscape_; +} +void Summary_Value::set_allocated_image(::mindspore::irpb::Summary_Image* image) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (image) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(image); + if (message_arena != submessage_arena) { + image = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, image, submessage_arena); + } + set_has_image(); + value_.image_ = image; + } + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Summary.Value.image) +} +void Summary_Value::set_allocated_tensor(::mindspore::irpb::TensorProto* tensor) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (tensor) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(tensor)->GetArena(); + if (message_arena != submessage_arena) { + tensor = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, tensor, submessage_arena); + } + set_has_tensor(); + value_.tensor_ = tensor; + } + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Summary.Value.tensor) +} +void Summary_Value::clear_tensor() { + if (_internal_has_tensor()) { + if (GetArena() == nullptr) { + delete value_.tensor_; + } + clear_has_value(); + } +} +void Summary_Value::set_allocated_histogram(::mindspore::irpb::Summary_Histogram* histogram) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (histogram) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(histogram); + if (message_arena != submessage_arena) { + histogram = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, histogram, submessage_arena); + } + set_has_histogram(); + value_.histogram_ = histogram; + } + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Summary.Value.histogram) +} +void Summary_Value::set_allocated_loss_landscape(::mindspore::irpb::LossLandscape* loss_landscape) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (loss_landscape) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(loss_landscape); + if (message_arena != submessage_arena) { + loss_landscape = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, loss_landscape, submessage_arena); + } + set_has_loss_landscape(); + value_.loss_landscape_ = loss_landscape; + } + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Summary.Value.loss_landscape) +} +Summary_Value::Summary_Value(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Summary.Value) +} +Summary_Value::Summary_Value(const Summary_Value& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + tag_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_tag()) { + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_tag(), + GetArena()); + } + clear_has_value(); + switch (from.value_case()) { + case kScalarValue: { + _internal_set_scalar_value(from._internal_scalar_value()); + break; + } + case kImage: { + _internal_mutable_image()->::mindspore::irpb::Summary_Image::MergeFrom(from._internal_image()); + break; + } + case kTensor: { + _internal_mutable_tensor()->::mindspore::irpb::TensorProto::MergeFrom(from._internal_tensor()); + break; + } + case kHistogram: { + _internal_mutable_histogram()->::mindspore::irpb::Summary_Histogram::MergeFrom(from._internal_histogram()); + break; + } + case kLossLandscape: { + _internal_mutable_loss_landscape()->::mindspore::irpb::LossLandscape::MergeFrom(from._internal_loss_landscape()); + break; + } + case VALUE_NOT_SET: { + break; + } + } + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Summary.Value) +} + +void Summary_Value::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Summary_Value_mindspore_5fsummary_2eproto.base); + tag_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + clear_has_value(); +} + +Summary_Value::~Summary_Value() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Summary.Value) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Summary_Value::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + tag_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (has_value()) { + clear_value(); + } +} + +void Summary_Value::ArenaDtor(void* object) { + Summary_Value* _this = reinterpret_cast< Summary_Value* >(object); + (void)_this; +} +void Summary_Value::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Summary_Value::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Summary_Value& Summary_Value::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Summary_Value_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Summary_Value::clear_value() { +// @@protoc_insertion_point(one_of_clear_start:mindspore.irpb.Summary.Value) + switch (value_case()) { + case kScalarValue: { + // No need to clear + break; + } + case kImage: { + if (GetArena() == nullptr) { + delete value_.image_; + } + break; + } + case kTensor: { + if (GetArena() == nullptr) { + delete value_.tensor_; + } + break; + } + case kHistogram: { + if (GetArena() == nullptr) { + delete value_.histogram_; + } + break; + } + case kLossLandscape: { + if (GetArena() == nullptr) { + delete value_.loss_landscape_; + } + break; + } + case VALUE_NOT_SET: { + break; + } + } + _oneof_case_[0] = VALUE_NOT_SET; +} + + +void Summary_Value::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Summary.Value) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + tag_.ClearNonDefaultToEmpty(); + } + clear_value(); + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Summary_Value::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // required string tag = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_tag(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Summary.Value.tag"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // float scalar_value = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 29)) { + _internal_set_scalar_value(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // .mindspore.irpb.Summary.Image image = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + ptr = ctx->ParseMessage(_internal_mutable_image(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .mindspore.irpb.TensorProto tensor = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + ptr = ctx->ParseMessage(_internal_mutable_tensor(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .mindspore.irpb.Summary.Histogram histogram = 9; + case 9: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 74)) { + ptr = ctx->ParseMessage(_internal_mutable_histogram(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .mindspore.irpb.LossLandscape loss_landscape = 10; + case 10: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 82)) { + ptr = ctx->ParseMessage(_internal_mutable_loss_landscape(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Summary_Value::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Summary.Value) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // required string tag = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_tag().data(), static_cast(this->_internal_tag().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Summary.Value.tag"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_tag(), target); + } + + switch (value_case()) { + case kScalarValue: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(3, this->_internal_scalar_value(), target); + break; + } + case kImage: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 4, _Internal::image(this), target, stream); + break; + } + case kTensor: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 8, _Internal::tensor(this), target, stream); + break; + } + case kHistogram: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 9, _Internal::histogram(this), target, stream); + break; + } + case kLossLandscape: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 10, _Internal::loss_landscape(this), target, stream); + break; + } + default: ; + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Summary.Value) + return target; +} + +size_t Summary_Value::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Summary.Value) + size_t total_size = 0; + + // required string tag = 1; + if (_internal_has_tag()) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_tag()); + } + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + switch (value_case()) { + // float scalar_value = 3; + case kScalarValue: { + total_size += 1 + 4; + break; + } + // .mindspore.irpb.Summary.Image image = 4; + case kImage: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.image_); + break; + } + // .mindspore.irpb.TensorProto tensor = 8; + case kTensor: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.tensor_); + break; + } + // .mindspore.irpb.Summary.Histogram histogram = 9; + case kHistogram: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.histogram_); + break; + } + // .mindspore.irpb.LossLandscape loss_landscape = 10; + case kLossLandscape: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.loss_landscape_); + break; + } + case VALUE_NOT_SET: { + break; + } + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Summary_Value::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Summary.Value) + GOOGLE_DCHECK_NE(&from, this); + const Summary_Value* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Summary.Value) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Summary.Value) + MergeFrom(*source); + } +} + +void Summary_Value::MergeFrom(const Summary_Value& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Summary.Value) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from._internal_has_tag()) { + _internal_set_tag(from._internal_tag()); + } + switch (from.value_case()) { + case kScalarValue: { + _internal_set_scalar_value(from._internal_scalar_value()); + break; + } + case kImage: { + _internal_mutable_image()->::mindspore::irpb::Summary_Image::MergeFrom(from._internal_image()); + break; + } + case kTensor: { + _internal_mutable_tensor()->::mindspore::irpb::TensorProto::MergeFrom(from._internal_tensor()); + break; + } + case kHistogram: { + _internal_mutable_histogram()->::mindspore::irpb::Summary_Histogram::MergeFrom(from._internal_histogram()); + break; + } + case kLossLandscape: { + _internal_mutable_loss_landscape()->::mindspore::irpb::LossLandscape::MergeFrom(from._internal_loss_landscape()); + break; + } + case VALUE_NOT_SET: { + break; + } + } +} + +void Summary_Value::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Summary.Value) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Summary_Value::CopyFrom(const Summary_Value& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Summary.Value) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Summary_Value::IsInitialized() const { + if (_Internal::MissingRequiredFields(_has_bits_)) return false; + switch (value_case()) { + case kScalarValue: { + break; + } + case kImage: { + if (has_image()) { + if (!this->image().IsInitialized()) return false; + } + break; + } + case kTensor: { + break; + } + case kHistogram: { + if (has_histogram()) { + if (!this->histogram().IsInitialized()) return false; + } + break; + } + case kLossLandscape: { + break; + } + case VALUE_NOT_SET: { + break; + } + } + return true; +} + +void Summary_Value::InternalSwap(Summary_Value* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + tag_.Swap(&other->tag_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(value_, other->value_); + swap(_oneof_case_[0], other->_oneof_case_[0]); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Summary_Value::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Summary::InitAsDefaultInstance() { +} +class Summary::_Internal { + public: +}; + +Summary::Summary(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + value_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Summary) +} +Summary::Summary(const Summary& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + value_(from.value_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Summary) +} + +void Summary::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Summary_mindspore_5fsummary_2eproto.base); +} + +Summary::~Summary() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Summary) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Summary::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void Summary::ArenaDtor(void* object) { + Summary* _this = reinterpret_cast< Summary* >(object); + (void)_this; +} +void Summary::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Summary::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Summary& Summary::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Summary_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Summary::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Summary) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + value_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Summary::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .mindspore.irpb.Summary.Value value = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_value(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Summary::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Summary) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .mindspore.irpb.Summary.Value value = 1; + for (unsigned int i = 0, + n = static_cast(this->_internal_value_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(1, this->_internal_value(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Summary) + return target; +} + +size_t Summary::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Summary) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .mindspore.irpb.Summary.Value value = 1; + total_size += 1UL * this->_internal_value_size(); + for (const auto& msg : this->value_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Summary::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Summary) + GOOGLE_DCHECK_NE(&from, this); + const Summary* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Summary) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Summary) + MergeFrom(*source); + } +} + +void Summary::MergeFrom(const Summary& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Summary) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + value_.MergeFrom(from.value_); +} + +void Summary::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Summary) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Summary::CopyFrom(const Summary& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Summary) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Summary::IsInitialized() const { + if (!::PROTOBUF_NAMESPACE_ID::internal::AllAreInitialized(value_)) return false; + return true; +} + +void Summary::InternalSwap(Summary* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + value_.InternalSwap(&other->value_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Summary::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Explain_Inference::InitAsDefaultInstance() { +} +class Explain_Inference::_Internal { + public: +}; + +Explain_Inference::Explain_Inference(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + ground_truth_prob_(arena), + predicted_label_(arena), + predicted_prob_(arena), + ground_truth_prob_sd_(arena), + ground_truth_prob_itl95_low_(arena), + ground_truth_prob_itl95_hi_(arena), + predicted_prob_sd_(arena), + predicted_prob_itl95_low_(arena), + predicted_prob_itl95_hi_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Explain.Inference) +} +Explain_Inference::Explain_Inference(const Explain_Inference& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + ground_truth_prob_(from.ground_truth_prob_), + predicted_label_(from.predicted_label_), + predicted_prob_(from.predicted_prob_), + ground_truth_prob_sd_(from.ground_truth_prob_sd_), + ground_truth_prob_itl95_low_(from.ground_truth_prob_itl95_low_), + ground_truth_prob_itl95_hi_(from.ground_truth_prob_itl95_hi_), + predicted_prob_sd_(from.predicted_prob_sd_), + predicted_prob_itl95_low_(from.predicted_prob_itl95_low_), + predicted_prob_itl95_hi_(from.predicted_prob_itl95_hi_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Explain.Inference) +} + +void Explain_Inference::SharedCtor() { +} + +Explain_Inference::~Explain_Inference() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Explain.Inference) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Explain_Inference::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void Explain_Inference::ArenaDtor(void* object) { + Explain_Inference* _this = reinterpret_cast< Explain_Inference* >(object); + (void)_this; +} +void Explain_Inference::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Explain_Inference::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Explain_Inference& Explain_Inference::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Explain_Inference_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Explain_Inference::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Explain.Inference) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + ground_truth_prob_.Clear(); + predicted_label_.Clear(); + predicted_prob_.Clear(); + ground_truth_prob_sd_.Clear(); + ground_truth_prob_itl95_low_.Clear(); + ground_truth_prob_itl95_hi_.Clear(); + predicted_prob_sd_.Clear(); + predicted_prob_itl95_low_.Clear(); + predicted_prob_itl95_hi_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Explain_Inference::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated float ground_truth_prob = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 13)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_ground_truth_prob(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<13>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_ground_truth_prob(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated int32 predicted_label = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_predicted_label(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<16>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt32Parser(_internal_mutable_predicted_label(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float predicted_prob = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 29)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_predicted_prob(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<29>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_predicted_prob(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float ground_truth_prob_sd = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 37)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_ground_truth_prob_sd(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<37>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_ground_truth_prob_sd(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float ground_truth_prob_itl95_low = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 45)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_ground_truth_prob_itl95_low(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<45>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_ground_truth_prob_itl95_low(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float ground_truth_prob_itl95_hi = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 53)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_ground_truth_prob_itl95_hi(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<53>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_ground_truth_prob_itl95_hi(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float predicted_prob_sd = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 61)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_predicted_prob_sd(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<61>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_predicted_prob_sd(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float predicted_prob_itl95_low = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 69)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_predicted_prob_itl95_low(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<69>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_predicted_prob_itl95_low(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float predicted_prob_itl95_hi = 9; + case 9: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 77)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_predicted_prob_itl95_hi(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<77>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 74) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_predicted_prob_itl95_hi(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Explain_Inference::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Explain.Inference) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated float ground_truth_prob = 1; + for (int i = 0, n = this->_internal_ground_truth_prob_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(1, this->_internal_ground_truth_prob(i), target); + } + + // repeated int32 predicted_label = 2; + for (int i = 0, n = this->_internal_predicted_label_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(2, this->_internal_predicted_label(i), target); + } + + // repeated float predicted_prob = 3; + for (int i = 0, n = this->_internal_predicted_prob_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(3, this->_internal_predicted_prob(i), target); + } + + // repeated float ground_truth_prob_sd = 4; + for (int i = 0, n = this->_internal_ground_truth_prob_sd_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(4, this->_internal_ground_truth_prob_sd(i), target); + } + + // repeated float ground_truth_prob_itl95_low = 5; + for (int i = 0, n = this->_internal_ground_truth_prob_itl95_low_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(5, this->_internal_ground_truth_prob_itl95_low(i), target); + } + + // repeated float ground_truth_prob_itl95_hi = 6; + for (int i = 0, n = this->_internal_ground_truth_prob_itl95_hi_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(6, this->_internal_ground_truth_prob_itl95_hi(i), target); + } + + // repeated float predicted_prob_sd = 7; + for (int i = 0, n = this->_internal_predicted_prob_sd_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(7, this->_internal_predicted_prob_sd(i), target); + } + + // repeated float predicted_prob_itl95_low = 8; + for (int i = 0, n = this->_internal_predicted_prob_itl95_low_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(8, this->_internal_predicted_prob_itl95_low(i), target); + } + + // repeated float predicted_prob_itl95_hi = 9; + for (int i = 0, n = this->_internal_predicted_prob_itl95_hi_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(9, this->_internal_predicted_prob_itl95_hi(i), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Explain.Inference) + return target; +} + +size_t Explain_Inference::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Explain.Inference) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated float ground_truth_prob = 1; + { + unsigned int count = static_cast(this->_internal_ground_truth_prob_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_ground_truth_prob_size()); + total_size += data_size; + } + + // repeated int32 predicted_label = 2; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int32Size(this->predicted_label_); + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_predicted_label_size()); + total_size += data_size; + } + + // repeated float predicted_prob = 3; + { + unsigned int count = static_cast(this->_internal_predicted_prob_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_predicted_prob_size()); + total_size += data_size; + } + + // repeated float ground_truth_prob_sd = 4; + { + unsigned int count = static_cast(this->_internal_ground_truth_prob_sd_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_ground_truth_prob_sd_size()); + total_size += data_size; + } + + // repeated float ground_truth_prob_itl95_low = 5; + { + unsigned int count = static_cast(this->_internal_ground_truth_prob_itl95_low_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_ground_truth_prob_itl95_low_size()); + total_size += data_size; + } + + // repeated float ground_truth_prob_itl95_hi = 6; + { + unsigned int count = static_cast(this->_internal_ground_truth_prob_itl95_hi_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_ground_truth_prob_itl95_hi_size()); + total_size += data_size; + } + + // repeated float predicted_prob_sd = 7; + { + unsigned int count = static_cast(this->_internal_predicted_prob_sd_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_predicted_prob_sd_size()); + total_size += data_size; + } + + // repeated float predicted_prob_itl95_low = 8; + { + unsigned int count = static_cast(this->_internal_predicted_prob_itl95_low_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_predicted_prob_itl95_low_size()); + total_size += data_size; + } + + // repeated float predicted_prob_itl95_hi = 9; + { + unsigned int count = static_cast(this->_internal_predicted_prob_itl95_hi_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_predicted_prob_itl95_hi_size()); + total_size += data_size; + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Explain_Inference::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Explain.Inference) + GOOGLE_DCHECK_NE(&from, this); + const Explain_Inference* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Explain.Inference) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Explain.Inference) + MergeFrom(*source); + } +} + +void Explain_Inference::MergeFrom(const Explain_Inference& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Explain.Inference) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + ground_truth_prob_.MergeFrom(from.ground_truth_prob_); + predicted_label_.MergeFrom(from.predicted_label_); + predicted_prob_.MergeFrom(from.predicted_prob_); + ground_truth_prob_sd_.MergeFrom(from.ground_truth_prob_sd_); + ground_truth_prob_itl95_low_.MergeFrom(from.ground_truth_prob_itl95_low_); + ground_truth_prob_itl95_hi_.MergeFrom(from.ground_truth_prob_itl95_hi_); + predicted_prob_sd_.MergeFrom(from.predicted_prob_sd_); + predicted_prob_itl95_low_.MergeFrom(from.predicted_prob_itl95_low_); + predicted_prob_itl95_hi_.MergeFrom(from.predicted_prob_itl95_hi_); +} + +void Explain_Inference::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Explain.Inference) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Explain_Inference::CopyFrom(const Explain_Inference& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Explain.Inference) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Explain_Inference::IsInitialized() const { + return true; +} + +void Explain_Inference::InternalSwap(Explain_Inference* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + ground_truth_prob_.InternalSwap(&other->ground_truth_prob_); + predicted_label_.InternalSwap(&other->predicted_label_); + predicted_prob_.InternalSwap(&other->predicted_prob_); + ground_truth_prob_sd_.InternalSwap(&other->ground_truth_prob_sd_); + ground_truth_prob_itl95_low_.InternalSwap(&other->ground_truth_prob_itl95_low_); + ground_truth_prob_itl95_hi_.InternalSwap(&other->ground_truth_prob_itl95_hi_); + predicted_prob_sd_.InternalSwap(&other->predicted_prob_sd_); + predicted_prob_itl95_low_.InternalSwap(&other->predicted_prob_itl95_low_); + predicted_prob_itl95_hi_.InternalSwap(&other->predicted_prob_itl95_hi_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Explain_Inference::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Explain_Explanation::InitAsDefaultInstance() { +} +class Explain_Explanation::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_explain_method(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_label(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static void set_has_heatmap_path(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } +}; + +Explain_Explanation::Explain_Explanation(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Explain.Explanation) +} +Explain_Explanation::Explain_Explanation(const Explain_Explanation& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + explain_method_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_explain_method()) { + explain_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_explain_method(), + GetArena()); + } + heatmap_path_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_heatmap_path()) { + heatmap_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_heatmap_path(), + GetArena()); + } + label_ = from.label_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Explain.Explanation) +} + +void Explain_Explanation::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Explain_Explanation_mindspore_5fsummary_2eproto.base); + explain_method_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + heatmap_path_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + label_ = 0; +} + +Explain_Explanation::~Explain_Explanation() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Explain.Explanation) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Explain_Explanation::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + explain_method_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + heatmap_path_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void Explain_Explanation::ArenaDtor(void* object) { + Explain_Explanation* _this = reinterpret_cast< Explain_Explanation* >(object); + (void)_this; +} +void Explain_Explanation::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Explain_Explanation::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Explain_Explanation& Explain_Explanation::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Explain_Explanation_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Explain_Explanation::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Explain.Explanation) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + explain_method_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + heatmap_path_.ClearNonDefaultToEmpty(); + } + } + label_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Explain_Explanation::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string explain_method = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_explain_method(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Explain.Explanation.explain_method"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional int32 label = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + _Internal::set_has_label(&has_bits); + label_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string heatmap_path = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_heatmap_path(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Explain.Explanation.heatmap_path"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Explain_Explanation::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Explain.Explanation) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string explain_method = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_explain_method().data(), static_cast(this->_internal_explain_method().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Explain.Explanation.explain_method"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_explain_method(), target); + } + + // optional int32 label = 2; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(2, this->_internal_label(), target); + } + + // optional string heatmap_path = 3; + if (cached_has_bits & 0x00000002u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_heatmap_path().data(), static_cast(this->_internal_heatmap_path().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Explain.Explanation.heatmap_path"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_heatmap_path(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Explain.Explanation) + return target; +} + +size_t Explain_Explanation::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Explain.Explanation) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + // optional string explain_method = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_explain_method()); + } + + // optional string heatmap_path = 3; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_heatmap_path()); + } + + // optional int32 label = 2; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_label()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Explain_Explanation::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Explain.Explanation) + GOOGLE_DCHECK_NE(&from, this); + const Explain_Explanation* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Explain.Explanation) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Explain.Explanation) + MergeFrom(*source); + } +} + +void Explain_Explanation::MergeFrom(const Explain_Explanation& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Explain.Explanation) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_explain_method(from._internal_explain_method()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_heatmap_path(from._internal_heatmap_path()); + } + if (cached_has_bits & 0x00000004u) { + label_ = from.label_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void Explain_Explanation::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Explain.Explanation) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Explain_Explanation::CopyFrom(const Explain_Explanation& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Explain.Explanation) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Explain_Explanation::IsInitialized() const { + return true; +} + +void Explain_Explanation::InternalSwap(Explain_Explanation* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + explain_method_.Swap(&other->explain_method_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + heatmap_path_.Swap(&other->heatmap_path_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(label_, other->label_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Explain_Explanation::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Explain_Benchmark::InitAsDefaultInstance() { +} +class Explain_Benchmark::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_benchmark_method(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_explain_method(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_total_score(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } +}; + +Explain_Benchmark::Explain_Benchmark(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + label_score_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Explain.Benchmark) +} +Explain_Benchmark::Explain_Benchmark(const Explain_Benchmark& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + label_score_(from.label_score_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + benchmark_method_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_benchmark_method()) { + benchmark_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_benchmark_method(), + GetArena()); + } + explain_method_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_explain_method()) { + explain_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_explain_method(), + GetArena()); + } + total_score_ = from.total_score_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Explain.Benchmark) +} + +void Explain_Benchmark::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Explain_Benchmark_mindspore_5fsummary_2eproto.base); + benchmark_method_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + explain_method_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + total_score_ = 0; +} + +Explain_Benchmark::~Explain_Benchmark() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Explain.Benchmark) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Explain_Benchmark::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + benchmark_method_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + explain_method_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void Explain_Benchmark::ArenaDtor(void* object) { + Explain_Benchmark* _this = reinterpret_cast< Explain_Benchmark* >(object); + (void)_this; +} +void Explain_Benchmark::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Explain_Benchmark::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Explain_Benchmark& Explain_Benchmark::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Explain_Benchmark_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Explain_Benchmark::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Explain.Benchmark) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + label_score_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + benchmark_method_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + explain_method_.ClearNonDefaultToEmpty(); + } + } + total_score_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Explain_Benchmark::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string benchmark_method = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_benchmark_method(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Explain.Benchmark.benchmark_method"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string explain_method = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_explain_method(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Explain.Benchmark.explain_method"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional float total_score = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 29)) { + _Internal::set_has_total_score(&has_bits); + total_score_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // repeated float label_score = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 37)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_label_score(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<37>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_label_score(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Explain_Benchmark::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Explain.Benchmark) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string benchmark_method = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_benchmark_method().data(), static_cast(this->_internal_benchmark_method().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Explain.Benchmark.benchmark_method"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_benchmark_method(), target); + } + + // optional string explain_method = 2; + if (cached_has_bits & 0x00000002u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_explain_method().data(), static_cast(this->_internal_explain_method().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Explain.Benchmark.explain_method"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_explain_method(), target); + } + + // optional float total_score = 3; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(3, this->_internal_total_score(), target); + } + + // repeated float label_score = 4; + for (int i = 0, n = this->_internal_label_score_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(4, this->_internal_label_score(i), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Explain.Benchmark) + return target; +} + +size_t Explain_Benchmark::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Explain.Benchmark) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated float label_score = 4; + { + unsigned int count = static_cast(this->_internal_label_score_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_label_score_size()); + total_size += data_size; + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + // optional string benchmark_method = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_benchmark_method()); + } + + // optional string explain_method = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_explain_method()); + } + + // optional float total_score = 3; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + 4; + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Explain_Benchmark::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Explain.Benchmark) + GOOGLE_DCHECK_NE(&from, this); + const Explain_Benchmark* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Explain.Benchmark) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Explain.Benchmark) + MergeFrom(*source); + } +} + +void Explain_Benchmark::MergeFrom(const Explain_Benchmark& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Explain.Benchmark) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + label_score_.MergeFrom(from.label_score_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_benchmark_method(from._internal_benchmark_method()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_explain_method(from._internal_explain_method()); + } + if (cached_has_bits & 0x00000004u) { + total_score_ = from.total_score_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void Explain_Benchmark::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Explain.Benchmark) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Explain_Benchmark::CopyFrom(const Explain_Benchmark& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Explain.Benchmark) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Explain_Benchmark::IsInitialized() const { + return true; +} + +void Explain_Benchmark::InternalSwap(Explain_Benchmark* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + label_score_.InternalSwap(&other->label_score_); + benchmark_method_.Swap(&other->benchmark_method_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + explain_method_.Swap(&other->explain_method_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(total_score_, other->total_score_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Explain_Benchmark::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Explain_Metadata::InitAsDefaultInstance() { +} +class Explain_Metadata::_Internal { + public: +}; + +Explain_Metadata::Explain_Metadata(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + label_(arena), + explain_method_(arena), + benchmark_method_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Explain.Metadata) +} +Explain_Metadata::Explain_Metadata(const Explain_Metadata& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + label_(from.label_), + explain_method_(from.explain_method_), + benchmark_method_(from.benchmark_method_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Explain.Metadata) +} + +void Explain_Metadata::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Explain_Metadata_mindspore_5fsummary_2eproto.base); +} + +Explain_Metadata::~Explain_Metadata() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Explain.Metadata) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Explain_Metadata::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void Explain_Metadata::ArenaDtor(void* object) { + Explain_Metadata* _this = reinterpret_cast< Explain_Metadata* >(object); + (void)_this; +} +void Explain_Metadata::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Explain_Metadata::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Explain_Metadata& Explain_Metadata::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Explain_Metadata_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Explain_Metadata::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Explain.Metadata) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + label_.Clear(); + explain_method_.Clear(); + benchmark_method_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Explain_Metadata::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated string label = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + auto str = _internal_add_label(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Explain.Metadata.label"); + #endif // !NDEBUG + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + // repeated string explain_method = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr -= 1; + do { + ptr += 1; + auto str = _internal_add_explain_method(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Explain.Metadata.explain_method"); + #endif // !NDEBUG + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<18>(ptr)); + } else goto handle_unusual; + continue; + // repeated string benchmark_method = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr -= 1; + do { + ptr += 1; + auto str = _internal_add_benchmark_method(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Explain.Metadata.benchmark_method"); + #endif // !NDEBUG + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<26>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Explain_Metadata::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Explain.Metadata) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated string label = 1; + for (int i = 0, n = this->_internal_label_size(); i < n; i++) { + const auto& s = this->_internal_label(i); + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + s.data(), static_cast(s.length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Explain.Metadata.label"); + target = stream->WriteString(1, s, target); + } + + // repeated string explain_method = 2; + for (int i = 0, n = this->_internal_explain_method_size(); i < n; i++) { + const auto& s = this->_internal_explain_method(i); + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + s.data(), static_cast(s.length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Explain.Metadata.explain_method"); + target = stream->WriteString(2, s, target); + } + + // repeated string benchmark_method = 3; + for (int i = 0, n = this->_internal_benchmark_method_size(); i < n; i++) { + const auto& s = this->_internal_benchmark_method(i); + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + s.data(), static_cast(s.length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Explain.Metadata.benchmark_method"); + target = stream->WriteString(3, s, target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Explain.Metadata) + return target; +} + +size_t Explain_Metadata::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Explain.Metadata) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated string label = 1; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(label_.size()); + for (int i = 0, n = label_.size(); i < n; i++) { + total_size += ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + label_.Get(i)); + } + + // repeated string explain_method = 2; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(explain_method_.size()); + for (int i = 0, n = explain_method_.size(); i < n; i++) { + total_size += ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + explain_method_.Get(i)); + } + + // repeated string benchmark_method = 3; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(benchmark_method_.size()); + for (int i = 0, n = benchmark_method_.size(); i < n; i++) { + total_size += ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + benchmark_method_.Get(i)); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Explain_Metadata::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Explain.Metadata) + GOOGLE_DCHECK_NE(&from, this); + const Explain_Metadata* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Explain.Metadata) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Explain.Metadata) + MergeFrom(*source); + } +} + +void Explain_Metadata::MergeFrom(const Explain_Metadata& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Explain.Metadata) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + label_.MergeFrom(from.label_); + explain_method_.MergeFrom(from.explain_method_); + benchmark_method_.MergeFrom(from.benchmark_method_); +} + +void Explain_Metadata::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Explain.Metadata) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Explain_Metadata::CopyFrom(const Explain_Metadata& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Explain.Metadata) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Explain_Metadata::IsInitialized() const { + return true; +} + +void Explain_Metadata::InternalSwap(Explain_Metadata* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + label_.InternalSwap(&other->label_); + explain_method_.InternalSwap(&other->explain_method_); + benchmark_method_.InternalSwap(&other->benchmark_method_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Explain_Metadata::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Explain_HocLayer::InitAsDefaultInstance() { +} +class Explain_HocLayer::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_prob(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +Explain_HocLayer::Explain_HocLayer(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + box_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Explain.HocLayer) +} +Explain_HocLayer::Explain_HocLayer(const Explain_HocLayer& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + box_(from.box_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + prob_ = from.prob_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Explain.HocLayer) +} + +void Explain_HocLayer::SharedCtor() { + prob_ = 0; +} + +Explain_HocLayer::~Explain_HocLayer() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Explain.HocLayer) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Explain_HocLayer::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void Explain_HocLayer::ArenaDtor(void* object) { + Explain_HocLayer* _this = reinterpret_cast< Explain_HocLayer* >(object); + (void)_this; +} +void Explain_HocLayer::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Explain_HocLayer::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Explain_HocLayer& Explain_HocLayer::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Explain_HocLayer_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Explain_HocLayer::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Explain.HocLayer) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + box_.Clear(); + prob_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Explain_HocLayer::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional float prob = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 13)) { + _Internal::set_has_prob(&has_bits); + prob_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // repeated int32 box = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_box(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<16>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt32Parser(_internal_mutable_box(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Explain_HocLayer::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Explain.HocLayer) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional float prob = 1; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(1, this->_internal_prob(), target); + } + + // repeated int32 box = 2; + for (int i = 0, n = this->_internal_box_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(2, this->_internal_box(i), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Explain.HocLayer) + return target; +} + +size_t Explain_HocLayer::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Explain.HocLayer) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated int32 box = 2; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int32Size(this->box_); + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_box_size()); + total_size += data_size; + } + + // optional float prob = 1; + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + 4; + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Explain_HocLayer::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Explain.HocLayer) + GOOGLE_DCHECK_NE(&from, this); + const Explain_HocLayer* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Explain.HocLayer) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Explain.HocLayer) + MergeFrom(*source); + } +} + +void Explain_HocLayer::MergeFrom(const Explain_HocLayer& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Explain.HocLayer) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + box_.MergeFrom(from.box_); + if (from._internal_has_prob()) { + _internal_set_prob(from._internal_prob()); + } +} + +void Explain_HocLayer::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Explain.HocLayer) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Explain_HocLayer::CopyFrom(const Explain_HocLayer& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Explain.HocLayer) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Explain_HocLayer::IsInitialized() const { + return true; +} + +void Explain_HocLayer::InternalSwap(Explain_HocLayer* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + box_.InternalSwap(&other->box_); + swap(prob_, other->prob_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Explain_HocLayer::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Explain_Hoc::InitAsDefaultInstance() { +} +class Explain_Hoc::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_label(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_mask(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +Explain_Hoc::Explain_Hoc(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + layer_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Explain.Hoc) +} +Explain_Hoc::Explain_Hoc(const Explain_Hoc& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + layer_(from.layer_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + mask_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_mask()) { + mask_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_mask(), + GetArena()); + } + label_ = from.label_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Explain.Hoc) +} + +void Explain_Hoc::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Explain_Hoc_mindspore_5fsummary_2eproto.base); + mask_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + label_ = 0; +} + +Explain_Hoc::~Explain_Hoc() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Explain.Hoc) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Explain_Hoc::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + mask_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void Explain_Hoc::ArenaDtor(void* object) { + Explain_Hoc* _this = reinterpret_cast< Explain_Hoc* >(object); + (void)_this; +} +void Explain_Hoc::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Explain_Hoc::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Explain_Hoc& Explain_Hoc::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Explain_Hoc_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Explain_Hoc::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Explain.Hoc) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + layer_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + mask_.ClearNonDefaultToEmpty(); + } + label_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Explain_Hoc::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional int32 label = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + _Internal::set_has_label(&has_bits); + label_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string mask = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_mask(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Explain.Hoc.mask"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.Explain.HocLayer layer = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_layer(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<26>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Explain_Hoc::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Explain.Hoc) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional int32 label = 1; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(1, this->_internal_label(), target); + } + + // optional string mask = 2; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_mask().data(), static_cast(this->_internal_mask().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Explain.Hoc.mask"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_mask(), target); + } + + // repeated .mindspore.irpb.Explain.HocLayer layer = 3; + for (unsigned int i = 0, + n = static_cast(this->_internal_layer_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(3, this->_internal_layer(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Explain.Hoc) + return target; +} + +size_t Explain_Hoc::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Explain.Hoc) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .mindspore.irpb.Explain.HocLayer layer = 3; + total_size += 1UL * this->_internal_layer_size(); + for (const auto& msg : this->layer_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional string mask = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_mask()); + } + + // optional int32 label = 1; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_label()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Explain_Hoc::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Explain.Hoc) + GOOGLE_DCHECK_NE(&from, this); + const Explain_Hoc* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Explain.Hoc) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Explain.Hoc) + MergeFrom(*source); + } +} + +void Explain_Hoc::MergeFrom(const Explain_Hoc& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Explain.Hoc) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + layer_.MergeFrom(from.layer_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_mask(from._internal_mask()); + } + if (cached_has_bits & 0x00000002u) { + label_ = from.label_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void Explain_Hoc::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Explain.Hoc) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Explain_Hoc::CopyFrom(const Explain_Hoc& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Explain.Hoc) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Explain_Hoc::IsInitialized() const { + return true; +} + +void Explain_Hoc::InternalSwap(Explain_Hoc* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + layer_.InternalSwap(&other->layer_); + mask_.Swap(&other->mask_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(label_, other->label_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Explain_Hoc::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Explain::InitAsDefaultInstance() { + ::mindspore::irpb::_Explain_default_instance_._instance.get_mutable()->inference_ = const_cast< ::mindspore::irpb::Explain_Inference*>( + ::mindspore::irpb::Explain_Inference::internal_default_instance()); + ::mindspore::irpb::_Explain_default_instance_._instance.get_mutable()->metadata_ = const_cast< ::mindspore::irpb::Explain_Metadata*>( + ::mindspore::irpb::Explain_Metadata::internal_default_instance()); +} +class Explain::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_sample_id(HasBits* has_bits) { + (*has_bits)[0] |= 16u; + } + static void set_has_image_path(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::mindspore::irpb::Explain_Inference& inference(const Explain* msg); + static void set_has_inference(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static const ::mindspore::irpb::Explain_Metadata& metadata(const Explain* msg); + static void set_has_metadata(HasBits* has_bits) { + (*has_bits)[0] |= 8u; + } + static void set_has_status(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } +}; + +const ::mindspore::irpb::Explain_Inference& +Explain::_Internal::inference(const Explain* msg) { + return *msg->inference_; +} +const ::mindspore::irpb::Explain_Metadata& +Explain::_Internal::metadata(const Explain* msg) { + return *msg->metadata_; +} +Explain::Explain(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + ground_truth_label_(arena), + explanation_(arena), + benchmark_(arena), + hoc_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:mindspore.irpb.Explain) +} +Explain::Explain(const Explain& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + ground_truth_label_(from.ground_truth_label_), + explanation_(from.explanation_), + benchmark_(from.benchmark_), + hoc_(from.hoc_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + image_path_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_image_path()) { + image_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_image_path(), + GetArena()); + } + status_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_status()) { + status_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_status(), + GetArena()); + } + if (from._internal_has_inference()) { + inference_ = new ::mindspore::irpb::Explain_Inference(*from.inference_); + } else { + inference_ = nullptr; + } + if (from._internal_has_metadata()) { + metadata_ = new ::mindspore::irpb::Explain_Metadata(*from.metadata_); + } else { + metadata_ = nullptr; + } + sample_id_ = from.sample_id_; + // @@protoc_insertion_point(copy_constructor:mindspore.irpb.Explain) +} + +void Explain::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Explain_mindspore_5fsummary_2eproto.base); + image_path_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + status_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&inference_, 0, static_cast( + reinterpret_cast(&sample_id_) - + reinterpret_cast(&inference_)) + sizeof(sample_id_)); +} + +Explain::~Explain() { + // @@protoc_insertion_point(destructor:mindspore.irpb.Explain) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Explain::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + image_path_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + status_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete inference_; + if (this != internal_default_instance()) delete metadata_; +} + +void Explain::ArenaDtor(void* object) { + Explain* _this = reinterpret_cast< Explain* >(object); + (void)_this; +} +void Explain::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Explain::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Explain& Explain::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Explain_mindspore_5fsummary_2eproto.base); + return *internal_default_instance(); +} + + +void Explain::Clear() { +// @@protoc_insertion_point(message_clear_start:mindspore.irpb.Explain) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + ground_truth_label_.Clear(); + explanation_.Clear(); + benchmark_.Clear(); + hoc_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000000fu) { + if (cached_has_bits & 0x00000001u) { + image_path_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + status_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000004u) { + GOOGLE_DCHECK(inference_ != nullptr); + inference_->Clear(); + } + if (cached_has_bits & 0x00000008u) { + GOOGLE_DCHECK(metadata_ != nullptr); + metadata_->Clear(); + } + } + sample_id_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Explain::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional int32 sample_id = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + _Internal::set_has_sample_id(&has_bits); + sample_id_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string image_path = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_image_path(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Explain.image_path"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated int32 ground_truth_label = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_ground_truth_label(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<24>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt32Parser(_internal_mutable_ground_truth_label(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.Explain.Inference inference = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + ptr = ctx->ParseMessage(_internal_mutable_inference(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.Explain.Explanation explanation = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_explanation(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<42>(ptr)); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.Explain.Benchmark benchmark = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_benchmark(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<50>(ptr)); + } else goto handle_unusual; + continue; + // optional .mindspore.irpb.Explain.Metadata metadata = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + ptr = ctx->ParseMessage(_internal_mutable_metadata(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string status = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + auto str = _internal_mutable_status(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "mindspore.irpb.Explain.status"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .mindspore.irpb.Explain.Hoc hoc = 9; + case 9: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 74)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_hoc(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<74>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Explain::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:mindspore.irpb.Explain) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional int32 sample_id = 1; + if (cached_has_bits & 0x00000010u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(1, this->_internal_sample_id(), target); + } + + // optional string image_path = 2; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_image_path().data(), static_cast(this->_internal_image_path().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Explain.image_path"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_image_path(), target); + } + + // repeated int32 ground_truth_label = 3; + for (int i = 0, n = this->_internal_ground_truth_label_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(3, this->_internal_ground_truth_label(i), target); + } + + // optional .mindspore.irpb.Explain.Inference inference = 4; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 4, _Internal::inference(this), target, stream); + } + + // repeated .mindspore.irpb.Explain.Explanation explanation = 5; + for (unsigned int i = 0, + n = static_cast(this->_internal_explanation_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(5, this->_internal_explanation(i), target, stream); + } + + // repeated .mindspore.irpb.Explain.Benchmark benchmark = 6; + for (unsigned int i = 0, + n = static_cast(this->_internal_benchmark_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(6, this->_internal_benchmark(i), target, stream); + } + + // optional .mindspore.irpb.Explain.Metadata metadata = 7; + if (cached_has_bits & 0x00000008u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 7, _Internal::metadata(this), target, stream); + } + + // optional string status = 8; + if (cached_has_bits & 0x00000002u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_status().data(), static_cast(this->_internal_status().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "mindspore.irpb.Explain.status"); + target = stream->WriteStringMaybeAliased( + 8, this->_internal_status(), target); + } + + // repeated .mindspore.irpb.Explain.Hoc hoc = 9; + for (unsigned int i = 0, + n = static_cast(this->_internal_hoc_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(9, this->_internal_hoc(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:mindspore.irpb.Explain) + return target; +} + +size_t Explain::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:mindspore.irpb.Explain) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated int32 ground_truth_label = 3; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int32Size(this->ground_truth_label_); + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_ground_truth_label_size()); + total_size += data_size; + } + + // repeated .mindspore.irpb.Explain.Explanation explanation = 5; + total_size += 1UL * this->_internal_explanation_size(); + for (const auto& msg : this->explanation_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .mindspore.irpb.Explain.Benchmark benchmark = 6; + total_size += 1UL * this->_internal_benchmark_size(); + for (const auto& msg : this->benchmark_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .mindspore.irpb.Explain.Hoc hoc = 9; + total_size += 1UL * this->_internal_hoc_size(); + for (const auto& msg : this->hoc_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000001fu) { + // optional string image_path = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_image_path()); + } + + // optional string status = 8; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_status()); + } + + // optional .mindspore.irpb.Explain.Inference inference = 4; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *inference_); + } + + // optional .mindspore.irpb.Explain.Metadata metadata = 7; + if (cached_has_bits & 0x00000008u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *metadata_); + } + + // optional int32 sample_id = 1; + if (cached_has_bits & 0x00000010u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_sample_id()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Explain::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:mindspore.irpb.Explain) + GOOGLE_DCHECK_NE(&from, this); + const Explain* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:mindspore.irpb.Explain) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:mindspore.irpb.Explain) + MergeFrom(*source); + } +} + +void Explain::MergeFrom(const Explain& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:mindspore.irpb.Explain) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + ground_truth_label_.MergeFrom(from.ground_truth_label_); + explanation_.MergeFrom(from.explanation_); + benchmark_.MergeFrom(from.benchmark_); + hoc_.MergeFrom(from.hoc_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x0000001fu) { + if (cached_has_bits & 0x00000001u) { + _internal_set_image_path(from._internal_image_path()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_status(from._internal_status()); + } + if (cached_has_bits & 0x00000004u) { + _internal_mutable_inference()->::mindspore::irpb::Explain_Inference::MergeFrom(from._internal_inference()); + } + if (cached_has_bits & 0x00000008u) { + _internal_mutable_metadata()->::mindspore::irpb::Explain_Metadata::MergeFrom(from._internal_metadata()); + } + if (cached_has_bits & 0x00000010u) { + sample_id_ = from.sample_id_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void Explain::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:mindspore.irpb.Explain) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Explain::CopyFrom(const Explain& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:mindspore.irpb.Explain) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Explain::IsInitialized() const { + return true; +} + +void Explain::InternalSwap(Explain* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + ground_truth_label_.InternalSwap(&other->ground_truth_label_); + explanation_.InternalSwap(&other->explanation_); + benchmark_.InternalSwap(&other->benchmark_); + hoc_.InternalSwap(&other->hoc_); + image_path_.Swap(&other->image_path_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + status_.Swap(&other->status_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(Explain, sample_id_) + + sizeof(Explain::sample_id_) + - PROTOBUF_FIELD_OFFSET(Explain, inference_)>( + reinterpret_cast(&inference_), + reinterpret_cast(&other->inference_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Explain::GetMetadata() const { + return GetMetadataStatic(); +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace irpb +} // namespace mindspore +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Event* Arena::CreateMaybeMessage< ::mindspore::irpb::Event >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Event >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::LossLandscape_Point* Arena::CreateMaybeMessage< ::mindspore::irpb::LossLandscape_Point >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::LossLandscape_Point >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::LossLandscape_LossPath* Arena::CreateMaybeMessage< ::mindspore::irpb::LossLandscape_LossPath >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::LossLandscape_LossPath >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::LossLandscape_Metadata* Arena::CreateMaybeMessage< ::mindspore::irpb::LossLandscape_Metadata >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::LossLandscape_Metadata >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::LossLandscape* Arena::CreateMaybeMessage< ::mindspore::irpb::LossLandscape >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::LossLandscape >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Summary_Image* Arena::CreateMaybeMessage< ::mindspore::irpb::Summary_Image >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Summary_Image >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Summary_Histogram_bucket* Arena::CreateMaybeMessage< ::mindspore::irpb::Summary_Histogram_bucket >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Summary_Histogram_bucket >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Summary_Histogram* Arena::CreateMaybeMessage< ::mindspore::irpb::Summary_Histogram >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Summary_Histogram >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Summary_Value* Arena::CreateMaybeMessage< ::mindspore::irpb::Summary_Value >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Summary_Value >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Summary* Arena::CreateMaybeMessage< ::mindspore::irpb::Summary >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Summary >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Explain_Inference* Arena::CreateMaybeMessage< ::mindspore::irpb::Explain_Inference >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Explain_Inference >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Explain_Explanation* Arena::CreateMaybeMessage< ::mindspore::irpb::Explain_Explanation >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Explain_Explanation >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Explain_Benchmark* Arena::CreateMaybeMessage< ::mindspore::irpb::Explain_Benchmark >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Explain_Benchmark >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Explain_Metadata* Arena::CreateMaybeMessage< ::mindspore::irpb::Explain_Metadata >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Explain_Metadata >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Explain_HocLayer* Arena::CreateMaybeMessage< ::mindspore::irpb::Explain_HocLayer >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Explain_HocLayer >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Explain_Hoc* Arena::CreateMaybeMessage< ::mindspore::irpb::Explain_Hoc >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Explain_Hoc >(arena); +} +template<> PROTOBUF_NOINLINE ::mindspore::irpb::Explain* Arena::CreateMaybeMessage< ::mindspore::irpb::Explain >(Arena* arena) { + return Arena::CreateMessageInternal< ::mindspore::irpb::Explain >(arena); +} +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) +#include diff --git a/plugins/mindstudio-insight-plugins/proto/mindspore_summary.pb.h b/plugins/mindstudio-insight-plugins/proto/mindspore_summary.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..7d073eded95d2904e08fae76a7998e5348d8ccec --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/mindspore_summary.pb.h @@ -0,0 +1,7989 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mindspore_summary.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_mindspore_5fsummary_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_mindspore_5fsummary_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include "mindspore_anf_ir.pb.h" +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_mindspore_5fsummary_2eproto +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct TableStruct_mindspore_5fsummary_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[17] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_mindspore_5fsummary_2eproto; +namespace mindspore { +namespace irpb { +class Event; +class EventDefaultTypeInternal; +extern EventDefaultTypeInternal _Event_default_instance_; +class Explain; +class ExplainDefaultTypeInternal; +extern ExplainDefaultTypeInternal _Explain_default_instance_; +class Explain_Benchmark; +class Explain_BenchmarkDefaultTypeInternal; +extern Explain_BenchmarkDefaultTypeInternal _Explain_Benchmark_default_instance_; +class Explain_Explanation; +class Explain_ExplanationDefaultTypeInternal; +extern Explain_ExplanationDefaultTypeInternal _Explain_Explanation_default_instance_; +class Explain_Hoc; +class Explain_HocDefaultTypeInternal; +extern Explain_HocDefaultTypeInternal _Explain_Hoc_default_instance_; +class Explain_HocLayer; +class Explain_HocLayerDefaultTypeInternal; +extern Explain_HocLayerDefaultTypeInternal _Explain_HocLayer_default_instance_; +class Explain_Inference; +class Explain_InferenceDefaultTypeInternal; +extern Explain_InferenceDefaultTypeInternal _Explain_Inference_default_instance_; +class Explain_Metadata; +class Explain_MetadataDefaultTypeInternal; +extern Explain_MetadataDefaultTypeInternal _Explain_Metadata_default_instance_; +class LossLandscape; +class LossLandscapeDefaultTypeInternal; +extern LossLandscapeDefaultTypeInternal _LossLandscape_default_instance_; +class LossLandscape_LossPath; +class LossLandscape_LossPathDefaultTypeInternal; +extern LossLandscape_LossPathDefaultTypeInternal _LossLandscape_LossPath_default_instance_; +class LossLandscape_Metadata; +class LossLandscape_MetadataDefaultTypeInternal; +extern LossLandscape_MetadataDefaultTypeInternal _LossLandscape_Metadata_default_instance_; +class LossLandscape_Point; +class LossLandscape_PointDefaultTypeInternal; +extern LossLandscape_PointDefaultTypeInternal _LossLandscape_Point_default_instance_; +class Summary; +class SummaryDefaultTypeInternal; +extern SummaryDefaultTypeInternal _Summary_default_instance_; +class Summary_Histogram; +class Summary_HistogramDefaultTypeInternal; +extern Summary_HistogramDefaultTypeInternal _Summary_Histogram_default_instance_; +class Summary_Histogram_bucket; +class Summary_Histogram_bucketDefaultTypeInternal; +extern Summary_Histogram_bucketDefaultTypeInternal _Summary_Histogram_bucket_default_instance_; +class Summary_Image; +class Summary_ImageDefaultTypeInternal; +extern Summary_ImageDefaultTypeInternal _Summary_Image_default_instance_; +class Summary_Value; +class Summary_ValueDefaultTypeInternal; +extern Summary_ValueDefaultTypeInternal _Summary_Value_default_instance_; +} // namespace irpb +} // namespace mindspore +PROTOBUF_NAMESPACE_OPEN +template<> ::mindspore::irpb::Event* Arena::CreateMaybeMessage<::mindspore::irpb::Event>(Arena*); +template<> ::mindspore::irpb::Explain* Arena::CreateMaybeMessage<::mindspore::irpb::Explain>(Arena*); +template<> ::mindspore::irpb::Explain_Benchmark* Arena::CreateMaybeMessage<::mindspore::irpb::Explain_Benchmark>(Arena*); +template<> ::mindspore::irpb::Explain_Explanation* Arena::CreateMaybeMessage<::mindspore::irpb::Explain_Explanation>(Arena*); +template<> ::mindspore::irpb::Explain_Hoc* Arena::CreateMaybeMessage<::mindspore::irpb::Explain_Hoc>(Arena*); +template<> ::mindspore::irpb::Explain_HocLayer* Arena::CreateMaybeMessage<::mindspore::irpb::Explain_HocLayer>(Arena*); +template<> ::mindspore::irpb::Explain_Inference* Arena::CreateMaybeMessage<::mindspore::irpb::Explain_Inference>(Arena*); +template<> ::mindspore::irpb::Explain_Metadata* Arena::CreateMaybeMessage<::mindspore::irpb::Explain_Metadata>(Arena*); +template<> ::mindspore::irpb::LossLandscape* Arena::CreateMaybeMessage<::mindspore::irpb::LossLandscape>(Arena*); +template<> ::mindspore::irpb::LossLandscape_LossPath* Arena::CreateMaybeMessage<::mindspore::irpb::LossLandscape_LossPath>(Arena*); +template<> ::mindspore::irpb::LossLandscape_Metadata* Arena::CreateMaybeMessage<::mindspore::irpb::LossLandscape_Metadata>(Arena*); +template<> ::mindspore::irpb::LossLandscape_Point* Arena::CreateMaybeMessage<::mindspore::irpb::LossLandscape_Point>(Arena*); +template<> ::mindspore::irpb::Summary* Arena::CreateMaybeMessage<::mindspore::irpb::Summary>(Arena*); +template<> ::mindspore::irpb::Summary_Histogram* Arena::CreateMaybeMessage<::mindspore::irpb::Summary_Histogram>(Arena*); +template<> ::mindspore::irpb::Summary_Histogram_bucket* Arena::CreateMaybeMessage<::mindspore::irpb::Summary_Histogram_bucket>(Arena*); +template<> ::mindspore::irpb::Summary_Image* Arena::CreateMaybeMessage<::mindspore::irpb::Summary_Image>(Arena*); +template<> ::mindspore::irpb::Summary_Value* Arena::CreateMaybeMessage<::mindspore::irpb::Summary_Value>(Arena*); +PROTOBUF_NAMESPACE_CLOSE +namespace mindspore { +namespace irpb { + +// =================================================================== + +class Event PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Event) */ { + public: + inline Event() : Event(nullptr) {} + virtual ~Event(); + + Event(const Event& from); + Event(Event&& from) noexcept + : Event() { + *this = ::std::move(from); + } + + inline Event& operator=(const Event& from) { + CopyFrom(from); + return *this; + } + inline Event& operator=(Event&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Event& default_instance(); + + enum WhatCase { + kVersion = 3, + kGraphDef = 4, + kSummary = 5, + kExplain = 6, + WHAT_NOT_SET = 0, + }; + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Event* internal_default_instance() { + return reinterpret_cast( + &_Event_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(Event& a, Event& b) { + a.Swap(&b); + } + inline void Swap(Event* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Event* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Event* New() const final { + return CreateMaybeMessage(nullptr); + } + + Event* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Event& from); + void MergeFrom(const Event& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Event* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Event"; + } + protected: + explicit Event(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kWallTimeFieldNumber = 1, + kStepFieldNumber = 2, + kVersionFieldNumber = 3, + kGraphDefFieldNumber = 4, + kSummaryFieldNumber = 5, + kExplainFieldNumber = 6, + }; + // required double wall_time = 1; + bool has_wall_time() const; + private: + bool _internal_has_wall_time() const; + public: + void clear_wall_time(); + double wall_time() const; + void set_wall_time(double value); + private: + double _internal_wall_time() const; + void _internal_set_wall_time(double value); + public: + + // optional int64 step = 2; + bool has_step() const; + private: + bool _internal_has_step() const; + public: + void clear_step(); + ::PROTOBUF_NAMESPACE_ID::int64 step() const; + void set_step(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_step() const; + void _internal_set_step(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // string version = 3; + bool has_version() const; + private: + bool _internal_has_version() const; + public: + void clear_version(); + const std::string& version() const; + void set_version(const std::string& value); + void set_version(std::string&& value); + void set_version(const char* value); + void set_version(const char* value, size_t size); + std::string* mutable_version(); + std::string* release_version(); + void set_allocated_version(std::string* version); + private: + const std::string& _internal_version() const; + void _internal_set_version(const std::string& value); + std::string* _internal_mutable_version(); + public: + + // .mindspore.irpb.GraphProto graph_def = 4; + bool has_graph_def() const; + private: + bool _internal_has_graph_def() const; + public: + void clear_graph_def(); + const ::mindspore::irpb::GraphProto& graph_def() const; + ::mindspore::irpb::GraphProto* release_graph_def(); + ::mindspore::irpb::GraphProto* mutable_graph_def(); + void set_allocated_graph_def(::mindspore::irpb::GraphProto* graph_def); + private: + const ::mindspore::irpb::GraphProto& _internal_graph_def() const; + ::mindspore::irpb::GraphProto* _internal_mutable_graph_def(); + public: + void unsafe_arena_set_allocated_graph_def( + ::mindspore::irpb::GraphProto* graph_def); + ::mindspore::irpb::GraphProto* unsafe_arena_release_graph_def(); + + // .mindspore.irpb.Summary summary = 5; + bool has_summary() const; + private: + bool _internal_has_summary() const; + public: + void clear_summary(); + const ::mindspore::irpb::Summary& summary() const; + ::mindspore::irpb::Summary* release_summary(); + ::mindspore::irpb::Summary* mutable_summary(); + void set_allocated_summary(::mindspore::irpb::Summary* summary); + private: + const ::mindspore::irpb::Summary& _internal_summary() const; + ::mindspore::irpb::Summary* _internal_mutable_summary(); + public: + void unsafe_arena_set_allocated_summary( + ::mindspore::irpb::Summary* summary); + ::mindspore::irpb::Summary* unsafe_arena_release_summary(); + + // .mindspore.irpb.Explain explain = 6; + bool has_explain() const; + private: + bool _internal_has_explain() const; + public: + void clear_explain(); + const ::mindspore::irpb::Explain& explain() const; + ::mindspore::irpb::Explain* release_explain(); + ::mindspore::irpb::Explain* mutable_explain(); + void set_allocated_explain(::mindspore::irpb::Explain* explain); + private: + const ::mindspore::irpb::Explain& _internal_explain() const; + ::mindspore::irpb::Explain* _internal_mutable_explain(); + public: + void unsafe_arena_set_allocated_explain( + ::mindspore::irpb::Explain* explain); + ::mindspore::irpb::Explain* unsafe_arena_release_explain(); + + void clear_what(); + WhatCase what_case() const; + // @@protoc_insertion_point(class_scope:mindspore.irpb.Event) + private: + class _Internal; + void set_has_version(); + void set_has_graph_def(); + void set_has_summary(); + void set_has_explain(); + + inline bool has_what() const; + inline void clear_has_what(); + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + double wall_time_; + ::PROTOBUF_NAMESPACE_ID::int64 step_; + union WhatUnion { + WhatUnion() {} + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr version_; + ::mindspore::irpb::GraphProto* graph_def_; + ::mindspore::irpb::Summary* summary_; + ::mindspore::irpb::Explain* explain_; + } what_; + ::PROTOBUF_NAMESPACE_ID::uint32 _oneof_case_[1]; + + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class LossLandscape_Point PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.LossLandscape.Point) */ { + public: + inline LossLandscape_Point() : LossLandscape_Point(nullptr) {} + virtual ~LossLandscape_Point(); + + LossLandscape_Point(const LossLandscape_Point& from); + LossLandscape_Point(LossLandscape_Point&& from) noexcept + : LossLandscape_Point() { + *this = ::std::move(from); + } + + inline LossLandscape_Point& operator=(const LossLandscape_Point& from) { + CopyFrom(from); + return *this; + } + inline LossLandscape_Point& operator=(LossLandscape_Point&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const LossLandscape_Point& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const LossLandscape_Point* internal_default_instance() { + return reinterpret_cast( + &_LossLandscape_Point_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(LossLandscape_Point& a, LossLandscape_Point& b) { + a.Swap(&b); + } + inline void Swap(LossLandscape_Point* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(LossLandscape_Point* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline LossLandscape_Point* New() const final { + return CreateMaybeMessage(nullptr); + } + + LossLandscape_Point* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const LossLandscape_Point& from); + void MergeFrom(const LossLandscape_Point& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(LossLandscape_Point* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.LossLandscape.Point"; + } + protected: + explicit LossLandscape_Point(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kXFieldNumber = 1, + kYFieldNumber = 2, + kZFieldNumber = 3, + }; + // optional .mindspore.irpb.TensorProto x = 1; + bool has_x() const; + private: + bool _internal_has_x() const; + public: + void clear_x(); + const ::mindspore::irpb::TensorProto& x() const; + ::mindspore::irpb::TensorProto* release_x(); + ::mindspore::irpb::TensorProto* mutable_x(); + void set_allocated_x(::mindspore::irpb::TensorProto* x); + private: + const ::mindspore::irpb::TensorProto& _internal_x() const; + ::mindspore::irpb::TensorProto* _internal_mutable_x(); + public: + void unsafe_arena_set_allocated_x( + ::mindspore::irpb::TensorProto* x); + ::mindspore::irpb::TensorProto* unsafe_arena_release_x(); + + // optional .mindspore.irpb.TensorProto y = 2; + bool has_y() const; + private: + bool _internal_has_y() const; + public: + void clear_y(); + const ::mindspore::irpb::TensorProto& y() const; + ::mindspore::irpb::TensorProto* release_y(); + ::mindspore::irpb::TensorProto* mutable_y(); + void set_allocated_y(::mindspore::irpb::TensorProto* y); + private: + const ::mindspore::irpb::TensorProto& _internal_y() const; + ::mindspore::irpb::TensorProto* _internal_mutable_y(); + public: + void unsafe_arena_set_allocated_y( + ::mindspore::irpb::TensorProto* y); + ::mindspore::irpb::TensorProto* unsafe_arena_release_y(); + + // optional .mindspore.irpb.TensorProto z = 3; + bool has_z() const; + private: + bool _internal_has_z() const; + public: + void clear_z(); + const ::mindspore::irpb::TensorProto& z() const; + ::mindspore::irpb::TensorProto* release_z(); + ::mindspore::irpb::TensorProto* mutable_z(); + void set_allocated_z(::mindspore::irpb::TensorProto* z); + private: + const ::mindspore::irpb::TensorProto& _internal_z() const; + ::mindspore::irpb::TensorProto* _internal_mutable_z(); + public: + void unsafe_arena_set_allocated_z( + ::mindspore::irpb::TensorProto* z); + ::mindspore::irpb::TensorProto* unsafe_arena_release_z(); + + // @@protoc_insertion_point(class_scope:mindspore.irpb.LossLandscape.Point) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::mindspore::irpb::TensorProto* x_; + ::mindspore::irpb::TensorProto* y_; + ::mindspore::irpb::TensorProto* z_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class LossLandscape_LossPath PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.LossLandscape.LossPath) */ { + public: + inline LossLandscape_LossPath() : LossLandscape_LossPath(nullptr) {} + virtual ~LossLandscape_LossPath(); + + LossLandscape_LossPath(const LossLandscape_LossPath& from); + LossLandscape_LossPath(LossLandscape_LossPath&& from) noexcept + : LossLandscape_LossPath() { + *this = ::std::move(from); + } + + inline LossLandscape_LossPath& operator=(const LossLandscape_LossPath& from) { + CopyFrom(from); + return *this; + } + inline LossLandscape_LossPath& operator=(LossLandscape_LossPath&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const LossLandscape_LossPath& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const LossLandscape_LossPath* internal_default_instance() { + return reinterpret_cast( + &_LossLandscape_LossPath_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(LossLandscape_LossPath& a, LossLandscape_LossPath& b) { + a.Swap(&b); + } + inline void Swap(LossLandscape_LossPath* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(LossLandscape_LossPath* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline LossLandscape_LossPath* New() const final { + return CreateMaybeMessage(nullptr); + } + + LossLandscape_LossPath* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const LossLandscape_LossPath& from); + void MergeFrom(const LossLandscape_LossPath& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(LossLandscape_LossPath* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.LossLandscape.LossPath"; + } + protected: + explicit LossLandscape_LossPath(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kIntervalsFieldNumber = 1, + kPointsFieldNumber = 2, + }; + // repeated int32 intervals = 1; + int intervals_size() const; + private: + int _internal_intervals_size() const; + public: + void clear_intervals(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_intervals(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_intervals() const; + void _internal_add_intervals(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_intervals(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 intervals(int index) const; + void set_intervals(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_intervals(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + intervals() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_intervals(); + + // optional .mindspore.irpb.LossLandscape.Point points = 2; + bool has_points() const; + private: + bool _internal_has_points() const; + public: + void clear_points(); + const ::mindspore::irpb::LossLandscape_Point& points() const; + ::mindspore::irpb::LossLandscape_Point* release_points(); + ::mindspore::irpb::LossLandscape_Point* mutable_points(); + void set_allocated_points(::mindspore::irpb::LossLandscape_Point* points); + private: + const ::mindspore::irpb::LossLandscape_Point& _internal_points() const; + ::mindspore::irpb::LossLandscape_Point* _internal_mutable_points(); + public: + void unsafe_arena_set_allocated_points( + ::mindspore::irpb::LossLandscape_Point* points); + ::mindspore::irpb::LossLandscape_Point* unsafe_arena_release_points(); + + // @@protoc_insertion_point(class_scope:mindspore.irpb.LossLandscape.LossPath) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > intervals_; + ::mindspore::irpb::LossLandscape_Point* points_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class LossLandscape_Metadata PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.LossLandscape.Metadata) */ { + public: + inline LossLandscape_Metadata() : LossLandscape_Metadata(nullptr) {} + virtual ~LossLandscape_Metadata(); + + LossLandscape_Metadata(const LossLandscape_Metadata& from); + LossLandscape_Metadata(LossLandscape_Metadata&& from) noexcept + : LossLandscape_Metadata() { + *this = ::std::move(from); + } + + inline LossLandscape_Metadata& operator=(const LossLandscape_Metadata& from) { + CopyFrom(from); + return *this; + } + inline LossLandscape_Metadata& operator=(LossLandscape_Metadata&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const LossLandscape_Metadata& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const LossLandscape_Metadata* internal_default_instance() { + return reinterpret_cast( + &_LossLandscape_Metadata_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(LossLandscape_Metadata& a, LossLandscape_Metadata& b) { + a.Swap(&b); + } + inline void Swap(LossLandscape_Metadata* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(LossLandscape_Metadata* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline LossLandscape_Metadata* New() const final { + return CreateMaybeMessage(nullptr); + } + + LossLandscape_Metadata* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const LossLandscape_Metadata& from); + void MergeFrom(const LossLandscape_Metadata& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(LossLandscape_Metadata* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.LossLandscape.Metadata"; + } + protected: + explicit LossLandscape_Metadata(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDecompositionFieldNumber = 1, + kUnitFieldNumber = 2, + kStepPerEpochFieldNumber = 3, + }; + // optional string decomposition = 1; + bool has_decomposition() const; + private: + bool _internal_has_decomposition() const; + public: + void clear_decomposition(); + const std::string& decomposition() const; + void set_decomposition(const std::string& value); + void set_decomposition(std::string&& value); + void set_decomposition(const char* value); + void set_decomposition(const char* value, size_t size); + std::string* mutable_decomposition(); + std::string* release_decomposition(); + void set_allocated_decomposition(std::string* decomposition); + private: + const std::string& _internal_decomposition() const; + void _internal_set_decomposition(const std::string& value); + std::string* _internal_mutable_decomposition(); + public: + + // optional string unit = 2; + bool has_unit() const; + private: + bool _internal_has_unit() const; + public: + void clear_unit(); + const std::string& unit() const; + void set_unit(const std::string& value); + void set_unit(std::string&& value); + void set_unit(const char* value); + void set_unit(const char* value, size_t size); + std::string* mutable_unit(); + std::string* release_unit(); + void set_allocated_unit(std::string* unit); + private: + const std::string& _internal_unit() const; + void _internal_set_unit(const std::string& value); + std::string* _internal_mutable_unit(); + public: + + // optional int32 step_per_epoch = 3; + bool has_step_per_epoch() const; + private: + bool _internal_has_step_per_epoch() const; + public: + void clear_step_per_epoch(); + ::PROTOBUF_NAMESPACE_ID::int32 step_per_epoch() const; + void set_step_per_epoch(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_step_per_epoch() const; + void _internal_set_step_per_epoch(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.LossLandscape.Metadata) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr decomposition_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr unit_; + ::PROTOBUF_NAMESPACE_ID::int32 step_per_epoch_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class LossLandscape PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.LossLandscape) */ { + public: + inline LossLandscape() : LossLandscape(nullptr) {} + virtual ~LossLandscape(); + + LossLandscape(const LossLandscape& from); + LossLandscape(LossLandscape&& from) noexcept + : LossLandscape() { + *this = ::std::move(from); + } + + inline LossLandscape& operator=(const LossLandscape& from) { + CopyFrom(from); + return *this; + } + inline LossLandscape& operator=(LossLandscape&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const LossLandscape& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const LossLandscape* internal_default_instance() { + return reinterpret_cast( + &_LossLandscape_default_instance_); + } + static constexpr int kIndexInFileMessages = + 4; + + friend void swap(LossLandscape& a, LossLandscape& b) { + a.Swap(&b); + } + inline void Swap(LossLandscape* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(LossLandscape* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline LossLandscape* New() const final { + return CreateMaybeMessage(nullptr); + } + + LossLandscape* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const LossLandscape& from); + void MergeFrom(const LossLandscape& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(LossLandscape* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.LossLandscape"; + } + protected: + explicit LossLandscape(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef LossLandscape_Point Point; + typedef LossLandscape_LossPath LossPath; + typedef LossLandscape_Metadata Metadata; + + // accessors ------------------------------------------------------- + + enum : int { + kLandscapeFieldNumber = 1, + kLossPathFieldNumber = 2, + kMetadataFieldNumber = 3, + kConvergencePointFieldNumber = 4, + }; + // optional .mindspore.irpb.LossLandscape.Point landscape = 1; + bool has_landscape() const; + private: + bool _internal_has_landscape() const; + public: + void clear_landscape(); + const ::mindspore::irpb::LossLandscape_Point& landscape() const; + ::mindspore::irpb::LossLandscape_Point* release_landscape(); + ::mindspore::irpb::LossLandscape_Point* mutable_landscape(); + void set_allocated_landscape(::mindspore::irpb::LossLandscape_Point* landscape); + private: + const ::mindspore::irpb::LossLandscape_Point& _internal_landscape() const; + ::mindspore::irpb::LossLandscape_Point* _internal_mutable_landscape(); + public: + void unsafe_arena_set_allocated_landscape( + ::mindspore::irpb::LossLandscape_Point* landscape); + ::mindspore::irpb::LossLandscape_Point* unsafe_arena_release_landscape(); + + // optional .mindspore.irpb.LossLandscape.LossPath loss_path = 2; + bool has_loss_path() const; + private: + bool _internal_has_loss_path() const; + public: + void clear_loss_path(); + const ::mindspore::irpb::LossLandscape_LossPath& loss_path() const; + ::mindspore::irpb::LossLandscape_LossPath* release_loss_path(); + ::mindspore::irpb::LossLandscape_LossPath* mutable_loss_path(); + void set_allocated_loss_path(::mindspore::irpb::LossLandscape_LossPath* loss_path); + private: + const ::mindspore::irpb::LossLandscape_LossPath& _internal_loss_path() const; + ::mindspore::irpb::LossLandscape_LossPath* _internal_mutable_loss_path(); + public: + void unsafe_arena_set_allocated_loss_path( + ::mindspore::irpb::LossLandscape_LossPath* loss_path); + ::mindspore::irpb::LossLandscape_LossPath* unsafe_arena_release_loss_path(); + + // optional .mindspore.irpb.LossLandscape.Metadata metadata = 3; + bool has_metadata() const; + private: + bool _internal_has_metadata() const; + public: + void clear_metadata(); + const ::mindspore::irpb::LossLandscape_Metadata& metadata() const; + ::mindspore::irpb::LossLandscape_Metadata* release_metadata(); + ::mindspore::irpb::LossLandscape_Metadata* mutable_metadata(); + void set_allocated_metadata(::mindspore::irpb::LossLandscape_Metadata* metadata); + private: + const ::mindspore::irpb::LossLandscape_Metadata& _internal_metadata() const; + ::mindspore::irpb::LossLandscape_Metadata* _internal_mutable_metadata(); + public: + void unsafe_arena_set_allocated_metadata( + ::mindspore::irpb::LossLandscape_Metadata* metadata); + ::mindspore::irpb::LossLandscape_Metadata* unsafe_arena_release_metadata(); + + // optional .mindspore.irpb.LossLandscape.Point convergence_point = 4; + bool has_convergence_point() const; + private: + bool _internal_has_convergence_point() const; + public: + void clear_convergence_point(); + const ::mindspore::irpb::LossLandscape_Point& convergence_point() const; + ::mindspore::irpb::LossLandscape_Point* release_convergence_point(); + ::mindspore::irpb::LossLandscape_Point* mutable_convergence_point(); + void set_allocated_convergence_point(::mindspore::irpb::LossLandscape_Point* convergence_point); + private: + const ::mindspore::irpb::LossLandscape_Point& _internal_convergence_point() const; + ::mindspore::irpb::LossLandscape_Point* _internal_mutable_convergence_point(); + public: + void unsafe_arena_set_allocated_convergence_point( + ::mindspore::irpb::LossLandscape_Point* convergence_point); + ::mindspore::irpb::LossLandscape_Point* unsafe_arena_release_convergence_point(); + + // @@protoc_insertion_point(class_scope:mindspore.irpb.LossLandscape) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::mindspore::irpb::LossLandscape_Point* landscape_; + ::mindspore::irpb::LossLandscape_LossPath* loss_path_; + ::mindspore::irpb::LossLandscape_Metadata* metadata_; + ::mindspore::irpb::LossLandscape_Point* convergence_point_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Summary_Image PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Summary.Image) */ { + public: + inline Summary_Image() : Summary_Image(nullptr) {} + virtual ~Summary_Image(); + + Summary_Image(const Summary_Image& from); + Summary_Image(Summary_Image&& from) noexcept + : Summary_Image() { + *this = ::std::move(from); + } + + inline Summary_Image& operator=(const Summary_Image& from) { + CopyFrom(from); + return *this; + } + inline Summary_Image& operator=(Summary_Image&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Summary_Image& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Summary_Image* internal_default_instance() { + return reinterpret_cast( + &_Summary_Image_default_instance_); + } + static constexpr int kIndexInFileMessages = + 5; + + friend void swap(Summary_Image& a, Summary_Image& b) { + a.Swap(&b); + } + inline void Swap(Summary_Image* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Summary_Image* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Summary_Image* New() const final { + return CreateMaybeMessage(nullptr); + } + + Summary_Image* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Summary_Image& from); + void MergeFrom(const Summary_Image& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Summary_Image* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Summary.Image"; + } + protected: + explicit Summary_Image(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kEncodedImageFieldNumber = 4, + kHeightFieldNumber = 1, + kWidthFieldNumber = 2, + kColorspaceFieldNumber = 3, + }; + // required bytes encoded_image = 4; + bool has_encoded_image() const; + private: + bool _internal_has_encoded_image() const; + public: + void clear_encoded_image(); + const std::string& encoded_image() const; + void set_encoded_image(const std::string& value); + void set_encoded_image(std::string&& value); + void set_encoded_image(const char* value); + void set_encoded_image(const void* value, size_t size); + std::string* mutable_encoded_image(); + std::string* release_encoded_image(); + void set_allocated_encoded_image(std::string* encoded_image); + private: + const std::string& _internal_encoded_image() const; + void _internal_set_encoded_image(const std::string& value); + std::string* _internal_mutable_encoded_image(); + public: + + // required int32 height = 1; + bool has_height() const; + private: + bool _internal_has_height() const; + public: + void clear_height(); + ::PROTOBUF_NAMESPACE_ID::int32 height() const; + void set_height(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_height() const; + void _internal_set_height(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // required int32 width = 2; + bool has_width() const; + private: + bool _internal_has_width() const; + public: + void clear_width(); + ::PROTOBUF_NAMESPACE_ID::int32 width() const; + void set_width(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_width() const; + void _internal_set_width(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // required int32 colorspace = 3; + bool has_colorspace() const; + private: + bool _internal_has_colorspace() const; + public: + void clear_colorspace(); + ::PROTOBUF_NAMESPACE_ID::int32 colorspace() const; + void set_colorspace(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_colorspace() const; + void _internal_set_colorspace(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Summary.Image) + private: + class _Internal; + + // helper for ByteSizeLong() + size_t RequiredFieldsByteSizeFallback() const; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr encoded_image_; + ::PROTOBUF_NAMESPACE_ID::int32 height_; + ::PROTOBUF_NAMESPACE_ID::int32 width_; + ::PROTOBUF_NAMESPACE_ID::int32 colorspace_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Summary_Histogram_bucket PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Summary.Histogram.bucket) */ { + public: + inline Summary_Histogram_bucket() : Summary_Histogram_bucket(nullptr) {} + virtual ~Summary_Histogram_bucket(); + + Summary_Histogram_bucket(const Summary_Histogram_bucket& from); + Summary_Histogram_bucket(Summary_Histogram_bucket&& from) noexcept + : Summary_Histogram_bucket() { + *this = ::std::move(from); + } + + inline Summary_Histogram_bucket& operator=(const Summary_Histogram_bucket& from) { + CopyFrom(from); + return *this; + } + inline Summary_Histogram_bucket& operator=(Summary_Histogram_bucket&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Summary_Histogram_bucket& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Summary_Histogram_bucket* internal_default_instance() { + return reinterpret_cast( + &_Summary_Histogram_bucket_default_instance_); + } + static constexpr int kIndexInFileMessages = + 6; + + friend void swap(Summary_Histogram_bucket& a, Summary_Histogram_bucket& b) { + a.Swap(&b); + } + inline void Swap(Summary_Histogram_bucket* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Summary_Histogram_bucket* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Summary_Histogram_bucket* New() const final { + return CreateMaybeMessage(nullptr); + } + + Summary_Histogram_bucket* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Summary_Histogram_bucket& from); + void MergeFrom(const Summary_Histogram_bucket& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Summary_Histogram_bucket* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Summary.Histogram.bucket"; + } + protected: + explicit Summary_Histogram_bucket(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kLeftFieldNumber = 1, + kWidthFieldNumber = 2, + kCountFieldNumber = 3, + }; + // required double left = 1; + bool has_left() const; + private: + bool _internal_has_left() const; + public: + void clear_left(); + double left() const; + void set_left(double value); + private: + double _internal_left() const; + void _internal_set_left(double value); + public: + + // required double width = 2; + bool has_width() const; + private: + bool _internal_has_width() const; + public: + void clear_width(); + double width() const; + void set_width(double value); + private: + double _internal_width() const; + void _internal_set_width(double value); + public: + + // required int64 count = 3; + bool has_count() const; + private: + bool _internal_has_count() const; + public: + void clear_count(); + ::PROTOBUF_NAMESPACE_ID::int64 count() const; + void set_count(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_count() const; + void _internal_set_count(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Summary.Histogram.bucket) + private: + class _Internal; + + // helper for ByteSizeLong() + size_t RequiredFieldsByteSizeFallback() const; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + double left_; + double width_; + ::PROTOBUF_NAMESPACE_ID::int64 count_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Summary_Histogram PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Summary.Histogram) */ { + public: + inline Summary_Histogram() : Summary_Histogram(nullptr) {} + virtual ~Summary_Histogram(); + + Summary_Histogram(const Summary_Histogram& from); + Summary_Histogram(Summary_Histogram&& from) noexcept + : Summary_Histogram() { + *this = ::std::move(from); + } + + inline Summary_Histogram& operator=(const Summary_Histogram& from) { + CopyFrom(from); + return *this; + } + inline Summary_Histogram& operator=(Summary_Histogram&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Summary_Histogram& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Summary_Histogram* internal_default_instance() { + return reinterpret_cast( + &_Summary_Histogram_default_instance_); + } + static constexpr int kIndexInFileMessages = + 7; + + friend void swap(Summary_Histogram& a, Summary_Histogram& b) { + a.Swap(&b); + } + inline void Swap(Summary_Histogram* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Summary_Histogram* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Summary_Histogram* New() const final { + return CreateMaybeMessage(nullptr); + } + + Summary_Histogram* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Summary_Histogram& from); + void MergeFrom(const Summary_Histogram& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Summary_Histogram* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Summary.Histogram"; + } + protected: + explicit Summary_Histogram(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef Summary_Histogram_bucket bucket; + + // accessors ------------------------------------------------------- + + enum : int { + kBucketsFieldNumber = 1, + kNanCountFieldNumber = 2, + kPosInfCountFieldNumber = 3, + kNegInfCountFieldNumber = 4, + kMaxFieldNumber = 5, + kMinFieldNumber = 6, + kSumFieldNumber = 7, + kCountFieldNumber = 8, + }; + // repeated .mindspore.irpb.Summary.Histogram.bucket buckets = 1; + int buckets_size() const; + private: + int _internal_buckets_size() const; + public: + void clear_buckets(); + ::mindspore::irpb::Summary_Histogram_bucket* mutable_buckets(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Summary_Histogram_bucket >* + mutable_buckets(); + private: + const ::mindspore::irpb::Summary_Histogram_bucket& _internal_buckets(int index) const; + ::mindspore::irpb::Summary_Histogram_bucket* _internal_add_buckets(); + public: + const ::mindspore::irpb::Summary_Histogram_bucket& buckets(int index) const; + ::mindspore::irpb::Summary_Histogram_bucket* add_buckets(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Summary_Histogram_bucket >& + buckets() const; + + // optional int64 nan_count = 2; + bool has_nan_count() const; + private: + bool _internal_has_nan_count() const; + public: + void clear_nan_count(); + ::PROTOBUF_NAMESPACE_ID::int64 nan_count() const; + void set_nan_count(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_nan_count() const; + void _internal_set_nan_count(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional int64 pos_inf_count = 3; + bool has_pos_inf_count() const; + private: + bool _internal_has_pos_inf_count() const; + public: + void clear_pos_inf_count(); + ::PROTOBUF_NAMESPACE_ID::int64 pos_inf_count() const; + void set_pos_inf_count(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_pos_inf_count() const; + void _internal_set_pos_inf_count(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional int64 neg_inf_count = 4; + bool has_neg_inf_count() const; + private: + bool _internal_has_neg_inf_count() const; + public: + void clear_neg_inf_count(); + ::PROTOBUF_NAMESPACE_ID::int64 neg_inf_count() const; + void set_neg_inf_count(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_neg_inf_count() const; + void _internal_set_neg_inf_count(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional double max = 5; + bool has_max() const; + private: + bool _internal_has_max() const; + public: + void clear_max(); + double max() const; + void set_max(double value); + private: + double _internal_max() const; + void _internal_set_max(double value); + public: + + // optional double min = 6; + bool has_min() const; + private: + bool _internal_has_min() const; + public: + void clear_min(); + double min() const; + void set_min(double value); + private: + double _internal_min() const; + void _internal_set_min(double value); + public: + + // optional double sum = 7; + bool has_sum() const; + private: + bool _internal_has_sum() const; + public: + void clear_sum(); + double sum() const; + void set_sum(double value); + private: + double _internal_sum() const; + void _internal_set_sum(double value); + public: + + // optional int64 count = 8; + bool has_count() const; + private: + bool _internal_has_count() const; + public: + void clear_count(); + ::PROTOBUF_NAMESPACE_ID::int64 count() const; + void set_count(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_count() const; + void _internal_set_count(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Summary.Histogram) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Summary_Histogram_bucket > buckets_; + ::PROTOBUF_NAMESPACE_ID::int64 nan_count_; + ::PROTOBUF_NAMESPACE_ID::int64 pos_inf_count_; + ::PROTOBUF_NAMESPACE_ID::int64 neg_inf_count_; + double max_; + double min_; + double sum_; + ::PROTOBUF_NAMESPACE_ID::int64 count_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Summary_Value PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Summary.Value) */ { + public: + inline Summary_Value() : Summary_Value(nullptr) {} + virtual ~Summary_Value(); + + Summary_Value(const Summary_Value& from); + Summary_Value(Summary_Value&& from) noexcept + : Summary_Value() { + *this = ::std::move(from); + } + + inline Summary_Value& operator=(const Summary_Value& from) { + CopyFrom(from); + return *this; + } + inline Summary_Value& operator=(Summary_Value&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Summary_Value& default_instance(); + + enum ValueCase { + kScalarValue = 3, + kImage = 4, + kTensor = 8, + kHistogram = 9, + kLossLandscape = 10, + VALUE_NOT_SET = 0, + }; + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Summary_Value* internal_default_instance() { + return reinterpret_cast( + &_Summary_Value_default_instance_); + } + static constexpr int kIndexInFileMessages = + 8; + + friend void swap(Summary_Value& a, Summary_Value& b) { + a.Swap(&b); + } + inline void Swap(Summary_Value* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Summary_Value* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Summary_Value* New() const final { + return CreateMaybeMessage(nullptr); + } + + Summary_Value* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Summary_Value& from); + void MergeFrom(const Summary_Value& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Summary_Value* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Summary.Value"; + } + protected: + explicit Summary_Value(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kTagFieldNumber = 1, + kScalarValueFieldNumber = 3, + kImageFieldNumber = 4, + kTensorFieldNumber = 8, + kHistogramFieldNumber = 9, + kLossLandscapeFieldNumber = 10, + }; + // required string tag = 1; + bool has_tag() const; + private: + bool _internal_has_tag() const; + public: + void clear_tag(); + const std::string& tag() const; + void set_tag(const std::string& value); + void set_tag(std::string&& value); + void set_tag(const char* value); + void set_tag(const char* value, size_t size); + std::string* mutable_tag(); + std::string* release_tag(); + void set_allocated_tag(std::string* tag); + private: + const std::string& _internal_tag() const; + void _internal_set_tag(const std::string& value); + std::string* _internal_mutable_tag(); + public: + + // float scalar_value = 3; + bool has_scalar_value() const; + private: + bool _internal_has_scalar_value() const; + public: + void clear_scalar_value(); + float scalar_value() const; + void set_scalar_value(float value); + private: + float _internal_scalar_value() const; + void _internal_set_scalar_value(float value); + public: + + // .mindspore.irpb.Summary.Image image = 4; + bool has_image() const; + private: + bool _internal_has_image() const; + public: + void clear_image(); + const ::mindspore::irpb::Summary_Image& image() const; + ::mindspore::irpb::Summary_Image* release_image(); + ::mindspore::irpb::Summary_Image* mutable_image(); + void set_allocated_image(::mindspore::irpb::Summary_Image* image); + private: + const ::mindspore::irpb::Summary_Image& _internal_image() const; + ::mindspore::irpb::Summary_Image* _internal_mutable_image(); + public: + void unsafe_arena_set_allocated_image( + ::mindspore::irpb::Summary_Image* image); + ::mindspore::irpb::Summary_Image* unsafe_arena_release_image(); + + // .mindspore.irpb.TensorProto tensor = 8; + bool has_tensor() const; + private: + bool _internal_has_tensor() const; + public: + void clear_tensor(); + const ::mindspore::irpb::TensorProto& tensor() const; + ::mindspore::irpb::TensorProto* release_tensor(); + ::mindspore::irpb::TensorProto* mutable_tensor(); + void set_allocated_tensor(::mindspore::irpb::TensorProto* tensor); + private: + const ::mindspore::irpb::TensorProto& _internal_tensor() const; + ::mindspore::irpb::TensorProto* _internal_mutable_tensor(); + public: + void unsafe_arena_set_allocated_tensor( + ::mindspore::irpb::TensorProto* tensor); + ::mindspore::irpb::TensorProto* unsafe_arena_release_tensor(); + + // .mindspore.irpb.Summary.Histogram histogram = 9; + bool has_histogram() const; + private: + bool _internal_has_histogram() const; + public: + void clear_histogram(); + const ::mindspore::irpb::Summary_Histogram& histogram() const; + ::mindspore::irpb::Summary_Histogram* release_histogram(); + ::mindspore::irpb::Summary_Histogram* mutable_histogram(); + void set_allocated_histogram(::mindspore::irpb::Summary_Histogram* histogram); + private: + const ::mindspore::irpb::Summary_Histogram& _internal_histogram() const; + ::mindspore::irpb::Summary_Histogram* _internal_mutable_histogram(); + public: + void unsafe_arena_set_allocated_histogram( + ::mindspore::irpb::Summary_Histogram* histogram); + ::mindspore::irpb::Summary_Histogram* unsafe_arena_release_histogram(); + + // .mindspore.irpb.LossLandscape loss_landscape = 10; + bool has_loss_landscape() const; + private: + bool _internal_has_loss_landscape() const; + public: + void clear_loss_landscape(); + const ::mindspore::irpb::LossLandscape& loss_landscape() const; + ::mindspore::irpb::LossLandscape* release_loss_landscape(); + ::mindspore::irpb::LossLandscape* mutable_loss_landscape(); + void set_allocated_loss_landscape(::mindspore::irpb::LossLandscape* loss_landscape); + private: + const ::mindspore::irpb::LossLandscape& _internal_loss_landscape() const; + ::mindspore::irpb::LossLandscape* _internal_mutable_loss_landscape(); + public: + void unsafe_arena_set_allocated_loss_landscape( + ::mindspore::irpb::LossLandscape* loss_landscape); + ::mindspore::irpb::LossLandscape* unsafe_arena_release_loss_landscape(); + + void clear_value(); + ValueCase value_case() const; + // @@protoc_insertion_point(class_scope:mindspore.irpb.Summary.Value) + private: + class _Internal; + void set_has_scalar_value(); + void set_has_image(); + void set_has_tensor(); + void set_has_histogram(); + void set_has_loss_landscape(); + + inline bool has_value() const; + inline void clear_has_value(); + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr tag_; + union ValueUnion { + ValueUnion() {} + float scalar_value_; + ::mindspore::irpb::Summary_Image* image_; + ::mindspore::irpb::TensorProto* tensor_; + ::mindspore::irpb::Summary_Histogram* histogram_; + ::mindspore::irpb::LossLandscape* loss_landscape_; + } value_; + ::PROTOBUF_NAMESPACE_ID::uint32 _oneof_case_[1]; + + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Summary PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Summary) */ { + public: + inline Summary() : Summary(nullptr) {} + virtual ~Summary(); + + Summary(const Summary& from); + Summary(Summary&& from) noexcept + : Summary() { + *this = ::std::move(from); + } + + inline Summary& operator=(const Summary& from) { + CopyFrom(from); + return *this; + } + inline Summary& operator=(Summary&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Summary& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Summary* internal_default_instance() { + return reinterpret_cast( + &_Summary_default_instance_); + } + static constexpr int kIndexInFileMessages = + 9; + + friend void swap(Summary& a, Summary& b) { + a.Swap(&b); + } + inline void Swap(Summary* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Summary* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Summary* New() const final { + return CreateMaybeMessage(nullptr); + } + + Summary* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Summary& from); + void MergeFrom(const Summary& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Summary* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Summary"; + } + protected: + explicit Summary(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef Summary_Image Image; + typedef Summary_Histogram Histogram; + typedef Summary_Value Value; + + // accessors ------------------------------------------------------- + + enum : int { + kValueFieldNumber = 1, + }; + // repeated .mindspore.irpb.Summary.Value value = 1; + int value_size() const; + private: + int _internal_value_size() const; + public: + void clear_value(); + ::mindspore::irpb::Summary_Value* mutable_value(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Summary_Value >* + mutable_value(); + private: + const ::mindspore::irpb::Summary_Value& _internal_value(int index) const; + ::mindspore::irpb::Summary_Value* _internal_add_value(); + public: + const ::mindspore::irpb::Summary_Value& value(int index) const; + ::mindspore::irpb::Summary_Value* add_value(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Summary_Value >& + value() const; + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Summary) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Summary_Value > value_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Explain_Inference PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Explain.Inference) */ { + public: + inline Explain_Inference() : Explain_Inference(nullptr) {} + virtual ~Explain_Inference(); + + Explain_Inference(const Explain_Inference& from); + Explain_Inference(Explain_Inference&& from) noexcept + : Explain_Inference() { + *this = ::std::move(from); + } + + inline Explain_Inference& operator=(const Explain_Inference& from) { + CopyFrom(from); + return *this; + } + inline Explain_Inference& operator=(Explain_Inference&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Explain_Inference& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Explain_Inference* internal_default_instance() { + return reinterpret_cast( + &_Explain_Inference_default_instance_); + } + static constexpr int kIndexInFileMessages = + 10; + + friend void swap(Explain_Inference& a, Explain_Inference& b) { + a.Swap(&b); + } + inline void Swap(Explain_Inference* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Explain_Inference* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Explain_Inference* New() const final { + return CreateMaybeMessage(nullptr); + } + + Explain_Inference* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Explain_Inference& from); + void MergeFrom(const Explain_Inference& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Explain_Inference* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Explain.Inference"; + } + protected: + explicit Explain_Inference(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kGroundTruthProbFieldNumber = 1, + kPredictedLabelFieldNumber = 2, + kPredictedProbFieldNumber = 3, + kGroundTruthProbSdFieldNumber = 4, + kGroundTruthProbItl95LowFieldNumber = 5, + kGroundTruthProbItl95HiFieldNumber = 6, + kPredictedProbSdFieldNumber = 7, + kPredictedProbItl95LowFieldNumber = 8, + kPredictedProbItl95HiFieldNumber = 9, + }; + // repeated float ground_truth_prob = 1; + int ground_truth_prob_size() const; + private: + int _internal_ground_truth_prob_size() const; + public: + void clear_ground_truth_prob(); + private: + float _internal_ground_truth_prob(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_ground_truth_prob() const; + void _internal_add_ground_truth_prob(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_ground_truth_prob(); + public: + float ground_truth_prob(int index) const; + void set_ground_truth_prob(int index, float value); + void add_ground_truth_prob(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + ground_truth_prob() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_ground_truth_prob(); + + // repeated int32 predicted_label = 2; + int predicted_label_size() const; + private: + int _internal_predicted_label_size() const; + public: + void clear_predicted_label(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_predicted_label(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_predicted_label() const; + void _internal_add_predicted_label(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_predicted_label(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 predicted_label(int index) const; + void set_predicted_label(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_predicted_label(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + predicted_label() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_predicted_label(); + + // repeated float predicted_prob = 3; + int predicted_prob_size() const; + private: + int _internal_predicted_prob_size() const; + public: + void clear_predicted_prob(); + private: + float _internal_predicted_prob(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_predicted_prob() const; + void _internal_add_predicted_prob(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_predicted_prob(); + public: + float predicted_prob(int index) const; + void set_predicted_prob(int index, float value); + void add_predicted_prob(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + predicted_prob() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_predicted_prob(); + + // repeated float ground_truth_prob_sd = 4; + int ground_truth_prob_sd_size() const; + private: + int _internal_ground_truth_prob_sd_size() const; + public: + void clear_ground_truth_prob_sd(); + private: + float _internal_ground_truth_prob_sd(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_ground_truth_prob_sd() const; + void _internal_add_ground_truth_prob_sd(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_ground_truth_prob_sd(); + public: + float ground_truth_prob_sd(int index) const; + void set_ground_truth_prob_sd(int index, float value); + void add_ground_truth_prob_sd(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + ground_truth_prob_sd() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_ground_truth_prob_sd(); + + // repeated float ground_truth_prob_itl95_low = 5; + int ground_truth_prob_itl95_low_size() const; + private: + int _internal_ground_truth_prob_itl95_low_size() const; + public: + void clear_ground_truth_prob_itl95_low(); + private: + float _internal_ground_truth_prob_itl95_low(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_ground_truth_prob_itl95_low() const; + void _internal_add_ground_truth_prob_itl95_low(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_ground_truth_prob_itl95_low(); + public: + float ground_truth_prob_itl95_low(int index) const; + void set_ground_truth_prob_itl95_low(int index, float value); + void add_ground_truth_prob_itl95_low(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + ground_truth_prob_itl95_low() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_ground_truth_prob_itl95_low(); + + // repeated float ground_truth_prob_itl95_hi = 6; + int ground_truth_prob_itl95_hi_size() const; + private: + int _internal_ground_truth_prob_itl95_hi_size() const; + public: + void clear_ground_truth_prob_itl95_hi(); + private: + float _internal_ground_truth_prob_itl95_hi(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_ground_truth_prob_itl95_hi() const; + void _internal_add_ground_truth_prob_itl95_hi(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_ground_truth_prob_itl95_hi(); + public: + float ground_truth_prob_itl95_hi(int index) const; + void set_ground_truth_prob_itl95_hi(int index, float value); + void add_ground_truth_prob_itl95_hi(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + ground_truth_prob_itl95_hi() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_ground_truth_prob_itl95_hi(); + + // repeated float predicted_prob_sd = 7; + int predicted_prob_sd_size() const; + private: + int _internal_predicted_prob_sd_size() const; + public: + void clear_predicted_prob_sd(); + private: + float _internal_predicted_prob_sd(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_predicted_prob_sd() const; + void _internal_add_predicted_prob_sd(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_predicted_prob_sd(); + public: + float predicted_prob_sd(int index) const; + void set_predicted_prob_sd(int index, float value); + void add_predicted_prob_sd(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + predicted_prob_sd() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_predicted_prob_sd(); + + // repeated float predicted_prob_itl95_low = 8; + int predicted_prob_itl95_low_size() const; + private: + int _internal_predicted_prob_itl95_low_size() const; + public: + void clear_predicted_prob_itl95_low(); + private: + float _internal_predicted_prob_itl95_low(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_predicted_prob_itl95_low() const; + void _internal_add_predicted_prob_itl95_low(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_predicted_prob_itl95_low(); + public: + float predicted_prob_itl95_low(int index) const; + void set_predicted_prob_itl95_low(int index, float value); + void add_predicted_prob_itl95_low(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + predicted_prob_itl95_low() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_predicted_prob_itl95_low(); + + // repeated float predicted_prob_itl95_hi = 9; + int predicted_prob_itl95_hi_size() const; + private: + int _internal_predicted_prob_itl95_hi_size() const; + public: + void clear_predicted_prob_itl95_hi(); + private: + float _internal_predicted_prob_itl95_hi(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_predicted_prob_itl95_hi() const; + void _internal_add_predicted_prob_itl95_hi(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_predicted_prob_itl95_hi(); + public: + float predicted_prob_itl95_hi(int index) const; + void set_predicted_prob_itl95_hi(int index, float value); + void add_predicted_prob_itl95_hi(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + predicted_prob_itl95_hi() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_predicted_prob_itl95_hi(); + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Explain.Inference) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > ground_truth_prob_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > predicted_label_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > predicted_prob_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > ground_truth_prob_sd_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > ground_truth_prob_itl95_low_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > ground_truth_prob_itl95_hi_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > predicted_prob_sd_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > predicted_prob_itl95_low_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > predicted_prob_itl95_hi_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Explain_Explanation PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Explain.Explanation) */ { + public: + inline Explain_Explanation() : Explain_Explanation(nullptr) {} + virtual ~Explain_Explanation(); + + Explain_Explanation(const Explain_Explanation& from); + Explain_Explanation(Explain_Explanation&& from) noexcept + : Explain_Explanation() { + *this = ::std::move(from); + } + + inline Explain_Explanation& operator=(const Explain_Explanation& from) { + CopyFrom(from); + return *this; + } + inline Explain_Explanation& operator=(Explain_Explanation&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Explain_Explanation& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Explain_Explanation* internal_default_instance() { + return reinterpret_cast( + &_Explain_Explanation_default_instance_); + } + static constexpr int kIndexInFileMessages = + 11; + + friend void swap(Explain_Explanation& a, Explain_Explanation& b) { + a.Swap(&b); + } + inline void Swap(Explain_Explanation* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Explain_Explanation* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Explain_Explanation* New() const final { + return CreateMaybeMessage(nullptr); + } + + Explain_Explanation* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Explain_Explanation& from); + void MergeFrom(const Explain_Explanation& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Explain_Explanation* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Explain.Explanation"; + } + protected: + explicit Explain_Explanation(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kExplainMethodFieldNumber = 1, + kHeatmapPathFieldNumber = 3, + kLabelFieldNumber = 2, + }; + // optional string explain_method = 1; + bool has_explain_method() const; + private: + bool _internal_has_explain_method() const; + public: + void clear_explain_method(); + const std::string& explain_method() const; + void set_explain_method(const std::string& value); + void set_explain_method(std::string&& value); + void set_explain_method(const char* value); + void set_explain_method(const char* value, size_t size); + std::string* mutable_explain_method(); + std::string* release_explain_method(); + void set_allocated_explain_method(std::string* explain_method); + private: + const std::string& _internal_explain_method() const; + void _internal_set_explain_method(const std::string& value); + std::string* _internal_mutable_explain_method(); + public: + + // optional string heatmap_path = 3; + bool has_heatmap_path() const; + private: + bool _internal_has_heatmap_path() const; + public: + void clear_heatmap_path(); + const std::string& heatmap_path() const; + void set_heatmap_path(const std::string& value); + void set_heatmap_path(std::string&& value); + void set_heatmap_path(const char* value); + void set_heatmap_path(const char* value, size_t size); + std::string* mutable_heatmap_path(); + std::string* release_heatmap_path(); + void set_allocated_heatmap_path(std::string* heatmap_path); + private: + const std::string& _internal_heatmap_path() const; + void _internal_set_heatmap_path(const std::string& value); + std::string* _internal_mutable_heatmap_path(); + public: + + // optional int32 label = 2; + bool has_label() const; + private: + bool _internal_has_label() const; + public: + void clear_label(); + ::PROTOBUF_NAMESPACE_ID::int32 label() const; + void set_label(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_label() const; + void _internal_set_label(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Explain.Explanation) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr explain_method_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr heatmap_path_; + ::PROTOBUF_NAMESPACE_ID::int32 label_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Explain_Benchmark PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Explain.Benchmark) */ { + public: + inline Explain_Benchmark() : Explain_Benchmark(nullptr) {} + virtual ~Explain_Benchmark(); + + Explain_Benchmark(const Explain_Benchmark& from); + Explain_Benchmark(Explain_Benchmark&& from) noexcept + : Explain_Benchmark() { + *this = ::std::move(from); + } + + inline Explain_Benchmark& operator=(const Explain_Benchmark& from) { + CopyFrom(from); + return *this; + } + inline Explain_Benchmark& operator=(Explain_Benchmark&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Explain_Benchmark& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Explain_Benchmark* internal_default_instance() { + return reinterpret_cast( + &_Explain_Benchmark_default_instance_); + } + static constexpr int kIndexInFileMessages = + 12; + + friend void swap(Explain_Benchmark& a, Explain_Benchmark& b) { + a.Swap(&b); + } + inline void Swap(Explain_Benchmark* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Explain_Benchmark* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Explain_Benchmark* New() const final { + return CreateMaybeMessage(nullptr); + } + + Explain_Benchmark* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Explain_Benchmark& from); + void MergeFrom(const Explain_Benchmark& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Explain_Benchmark* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Explain.Benchmark"; + } + protected: + explicit Explain_Benchmark(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kLabelScoreFieldNumber = 4, + kBenchmarkMethodFieldNumber = 1, + kExplainMethodFieldNumber = 2, + kTotalScoreFieldNumber = 3, + }; + // repeated float label_score = 4; + int label_score_size() const; + private: + int _internal_label_score_size() const; + public: + void clear_label_score(); + private: + float _internal_label_score(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_label_score() const; + void _internal_add_label_score(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_label_score(); + public: + float label_score(int index) const; + void set_label_score(int index, float value); + void add_label_score(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + label_score() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_label_score(); + + // optional string benchmark_method = 1; + bool has_benchmark_method() const; + private: + bool _internal_has_benchmark_method() const; + public: + void clear_benchmark_method(); + const std::string& benchmark_method() const; + void set_benchmark_method(const std::string& value); + void set_benchmark_method(std::string&& value); + void set_benchmark_method(const char* value); + void set_benchmark_method(const char* value, size_t size); + std::string* mutable_benchmark_method(); + std::string* release_benchmark_method(); + void set_allocated_benchmark_method(std::string* benchmark_method); + private: + const std::string& _internal_benchmark_method() const; + void _internal_set_benchmark_method(const std::string& value); + std::string* _internal_mutable_benchmark_method(); + public: + + // optional string explain_method = 2; + bool has_explain_method() const; + private: + bool _internal_has_explain_method() const; + public: + void clear_explain_method(); + const std::string& explain_method() const; + void set_explain_method(const std::string& value); + void set_explain_method(std::string&& value); + void set_explain_method(const char* value); + void set_explain_method(const char* value, size_t size); + std::string* mutable_explain_method(); + std::string* release_explain_method(); + void set_allocated_explain_method(std::string* explain_method); + private: + const std::string& _internal_explain_method() const; + void _internal_set_explain_method(const std::string& value); + std::string* _internal_mutable_explain_method(); + public: + + // optional float total_score = 3; + bool has_total_score() const; + private: + bool _internal_has_total_score() const; + public: + void clear_total_score(); + float total_score() const; + void set_total_score(float value); + private: + float _internal_total_score() const; + void _internal_set_total_score(float value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Explain.Benchmark) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > label_score_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr benchmark_method_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr explain_method_; + float total_score_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Explain_Metadata PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Explain.Metadata) */ { + public: + inline Explain_Metadata() : Explain_Metadata(nullptr) {} + virtual ~Explain_Metadata(); + + Explain_Metadata(const Explain_Metadata& from); + Explain_Metadata(Explain_Metadata&& from) noexcept + : Explain_Metadata() { + *this = ::std::move(from); + } + + inline Explain_Metadata& operator=(const Explain_Metadata& from) { + CopyFrom(from); + return *this; + } + inline Explain_Metadata& operator=(Explain_Metadata&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Explain_Metadata& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Explain_Metadata* internal_default_instance() { + return reinterpret_cast( + &_Explain_Metadata_default_instance_); + } + static constexpr int kIndexInFileMessages = + 13; + + friend void swap(Explain_Metadata& a, Explain_Metadata& b) { + a.Swap(&b); + } + inline void Swap(Explain_Metadata* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Explain_Metadata* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Explain_Metadata* New() const final { + return CreateMaybeMessage(nullptr); + } + + Explain_Metadata* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Explain_Metadata& from); + void MergeFrom(const Explain_Metadata& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Explain_Metadata* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Explain.Metadata"; + } + protected: + explicit Explain_Metadata(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kLabelFieldNumber = 1, + kExplainMethodFieldNumber = 2, + kBenchmarkMethodFieldNumber = 3, + }; + // repeated string label = 1; + int label_size() const; + private: + int _internal_label_size() const; + public: + void clear_label(); + const std::string& label(int index) const; + std::string* mutable_label(int index); + void set_label(int index, const std::string& value); + void set_label(int index, std::string&& value); + void set_label(int index, const char* value); + void set_label(int index, const char* value, size_t size); + std::string* add_label(); + void add_label(const std::string& value); + void add_label(std::string&& value); + void add_label(const char* value); + void add_label(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& label() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_label(); + private: + const std::string& _internal_label(int index) const; + std::string* _internal_add_label(); + public: + + // repeated string explain_method = 2; + int explain_method_size() const; + private: + int _internal_explain_method_size() const; + public: + void clear_explain_method(); + const std::string& explain_method(int index) const; + std::string* mutable_explain_method(int index); + void set_explain_method(int index, const std::string& value); + void set_explain_method(int index, std::string&& value); + void set_explain_method(int index, const char* value); + void set_explain_method(int index, const char* value, size_t size); + std::string* add_explain_method(); + void add_explain_method(const std::string& value); + void add_explain_method(std::string&& value); + void add_explain_method(const char* value); + void add_explain_method(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& explain_method() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_explain_method(); + private: + const std::string& _internal_explain_method(int index) const; + std::string* _internal_add_explain_method(); + public: + + // repeated string benchmark_method = 3; + int benchmark_method_size() const; + private: + int _internal_benchmark_method_size() const; + public: + void clear_benchmark_method(); + const std::string& benchmark_method(int index) const; + std::string* mutable_benchmark_method(int index); + void set_benchmark_method(int index, const std::string& value); + void set_benchmark_method(int index, std::string&& value); + void set_benchmark_method(int index, const char* value); + void set_benchmark_method(int index, const char* value, size_t size); + std::string* add_benchmark_method(); + void add_benchmark_method(const std::string& value); + void add_benchmark_method(std::string&& value); + void add_benchmark_method(const char* value); + void add_benchmark_method(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& benchmark_method() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_benchmark_method(); + private: + const std::string& _internal_benchmark_method(int index) const; + std::string* _internal_add_benchmark_method(); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Explain.Metadata) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField label_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField explain_method_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField benchmark_method_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Explain_HocLayer PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Explain.HocLayer) */ { + public: + inline Explain_HocLayer() : Explain_HocLayer(nullptr) {} + virtual ~Explain_HocLayer(); + + Explain_HocLayer(const Explain_HocLayer& from); + Explain_HocLayer(Explain_HocLayer&& from) noexcept + : Explain_HocLayer() { + *this = ::std::move(from); + } + + inline Explain_HocLayer& operator=(const Explain_HocLayer& from) { + CopyFrom(from); + return *this; + } + inline Explain_HocLayer& operator=(Explain_HocLayer&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Explain_HocLayer& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Explain_HocLayer* internal_default_instance() { + return reinterpret_cast( + &_Explain_HocLayer_default_instance_); + } + static constexpr int kIndexInFileMessages = + 14; + + friend void swap(Explain_HocLayer& a, Explain_HocLayer& b) { + a.Swap(&b); + } + inline void Swap(Explain_HocLayer* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Explain_HocLayer* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Explain_HocLayer* New() const final { + return CreateMaybeMessage(nullptr); + } + + Explain_HocLayer* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Explain_HocLayer& from); + void MergeFrom(const Explain_HocLayer& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Explain_HocLayer* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Explain.HocLayer"; + } + protected: + explicit Explain_HocLayer(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kBoxFieldNumber = 2, + kProbFieldNumber = 1, + }; + // repeated int32 box = 2; + int box_size() const; + private: + int _internal_box_size() const; + public: + void clear_box(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_box(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_box() const; + void _internal_add_box(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_box(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 box(int index) const; + void set_box(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_box(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + box() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_box(); + + // optional float prob = 1; + bool has_prob() const; + private: + bool _internal_has_prob() const; + public: + void clear_prob(); + float prob() const; + void set_prob(float value); + private: + float _internal_prob() const; + void _internal_set_prob(float value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Explain.HocLayer) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > box_; + float prob_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Explain_Hoc PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Explain.Hoc) */ { + public: + inline Explain_Hoc() : Explain_Hoc(nullptr) {} + virtual ~Explain_Hoc(); + + Explain_Hoc(const Explain_Hoc& from); + Explain_Hoc(Explain_Hoc&& from) noexcept + : Explain_Hoc() { + *this = ::std::move(from); + } + + inline Explain_Hoc& operator=(const Explain_Hoc& from) { + CopyFrom(from); + return *this; + } + inline Explain_Hoc& operator=(Explain_Hoc&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Explain_Hoc& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Explain_Hoc* internal_default_instance() { + return reinterpret_cast( + &_Explain_Hoc_default_instance_); + } + static constexpr int kIndexInFileMessages = + 15; + + friend void swap(Explain_Hoc& a, Explain_Hoc& b) { + a.Swap(&b); + } + inline void Swap(Explain_Hoc* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Explain_Hoc* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Explain_Hoc* New() const final { + return CreateMaybeMessage(nullptr); + } + + Explain_Hoc* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Explain_Hoc& from); + void MergeFrom(const Explain_Hoc& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Explain_Hoc* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Explain.Hoc"; + } + protected: + explicit Explain_Hoc(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kLayerFieldNumber = 3, + kMaskFieldNumber = 2, + kLabelFieldNumber = 1, + }; + // repeated .mindspore.irpb.Explain.HocLayer layer = 3; + int layer_size() const; + private: + int _internal_layer_size() const; + public: + void clear_layer(); + ::mindspore::irpb::Explain_HocLayer* mutable_layer(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_HocLayer >* + mutable_layer(); + private: + const ::mindspore::irpb::Explain_HocLayer& _internal_layer(int index) const; + ::mindspore::irpb::Explain_HocLayer* _internal_add_layer(); + public: + const ::mindspore::irpb::Explain_HocLayer& layer(int index) const; + ::mindspore::irpb::Explain_HocLayer* add_layer(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_HocLayer >& + layer() const; + + // optional string mask = 2; + bool has_mask() const; + private: + bool _internal_has_mask() const; + public: + void clear_mask(); + const std::string& mask() const; + void set_mask(const std::string& value); + void set_mask(std::string&& value); + void set_mask(const char* value); + void set_mask(const char* value, size_t size); + std::string* mutable_mask(); + std::string* release_mask(); + void set_allocated_mask(std::string* mask); + private: + const std::string& _internal_mask() const; + void _internal_set_mask(const std::string& value); + std::string* _internal_mutable_mask(); + public: + + // optional int32 label = 1; + bool has_label() const; + private: + bool _internal_has_label() const; + public: + void clear_label(); + ::PROTOBUF_NAMESPACE_ID::int32 label() const; + void set_label(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_label() const; + void _internal_set_label(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Explain.Hoc) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_HocLayer > layer_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr mask_; + ::PROTOBUF_NAMESPACE_ID::int32 label_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// ------------------------------------------------------------------- + +class Explain PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:mindspore.irpb.Explain) */ { + public: + inline Explain() : Explain(nullptr) {} + virtual ~Explain(); + + Explain(const Explain& from); + Explain(Explain&& from) noexcept + : Explain() { + *this = ::std::move(from); + } + + inline Explain& operator=(const Explain& from) { + CopyFrom(from); + return *this; + } + inline Explain& operator=(Explain&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Explain& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Explain* internal_default_instance() { + return reinterpret_cast( + &_Explain_default_instance_); + } + static constexpr int kIndexInFileMessages = + 16; + + friend void swap(Explain& a, Explain& b) { + a.Swap(&b); + } + inline void Swap(Explain* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Explain* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Explain* New() const final { + return CreateMaybeMessage(nullptr); + } + + Explain* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Explain& from); + void MergeFrom(const Explain& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Explain* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "mindspore.irpb.Explain"; + } + protected: + explicit Explain(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_mindspore_5fsummary_2eproto); + return ::descriptor_table_mindspore_5fsummary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef Explain_Inference Inference; + typedef Explain_Explanation Explanation; + typedef Explain_Benchmark Benchmark; + typedef Explain_Metadata Metadata; + typedef Explain_HocLayer HocLayer; + typedef Explain_Hoc Hoc; + + // accessors ------------------------------------------------------- + + enum : int { + kGroundTruthLabelFieldNumber = 3, + kExplanationFieldNumber = 5, + kBenchmarkFieldNumber = 6, + kHocFieldNumber = 9, + kImagePathFieldNumber = 2, + kStatusFieldNumber = 8, + kInferenceFieldNumber = 4, + kMetadataFieldNumber = 7, + kSampleIdFieldNumber = 1, + }; + // repeated int32 ground_truth_label = 3; + int ground_truth_label_size() const; + private: + int _internal_ground_truth_label_size() const; + public: + void clear_ground_truth_label(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_ground_truth_label(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_ground_truth_label() const; + void _internal_add_ground_truth_label(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_ground_truth_label(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 ground_truth_label(int index) const; + void set_ground_truth_label(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_ground_truth_label(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + ground_truth_label() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_ground_truth_label(); + + // repeated .mindspore.irpb.Explain.Explanation explanation = 5; + int explanation_size() const; + private: + int _internal_explanation_size() const; + public: + void clear_explanation(); + ::mindspore::irpb::Explain_Explanation* mutable_explanation(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Explanation >* + mutable_explanation(); + private: + const ::mindspore::irpb::Explain_Explanation& _internal_explanation(int index) const; + ::mindspore::irpb::Explain_Explanation* _internal_add_explanation(); + public: + const ::mindspore::irpb::Explain_Explanation& explanation(int index) const; + ::mindspore::irpb::Explain_Explanation* add_explanation(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Explanation >& + explanation() const; + + // repeated .mindspore.irpb.Explain.Benchmark benchmark = 6; + int benchmark_size() const; + private: + int _internal_benchmark_size() const; + public: + void clear_benchmark(); + ::mindspore::irpb::Explain_Benchmark* mutable_benchmark(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Benchmark >* + mutable_benchmark(); + private: + const ::mindspore::irpb::Explain_Benchmark& _internal_benchmark(int index) const; + ::mindspore::irpb::Explain_Benchmark* _internal_add_benchmark(); + public: + const ::mindspore::irpb::Explain_Benchmark& benchmark(int index) const; + ::mindspore::irpb::Explain_Benchmark* add_benchmark(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Benchmark >& + benchmark() const; + + // repeated .mindspore.irpb.Explain.Hoc hoc = 9; + int hoc_size() const; + private: + int _internal_hoc_size() const; + public: + void clear_hoc(); + ::mindspore::irpb::Explain_Hoc* mutable_hoc(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Hoc >* + mutable_hoc(); + private: + const ::mindspore::irpb::Explain_Hoc& _internal_hoc(int index) const; + ::mindspore::irpb::Explain_Hoc* _internal_add_hoc(); + public: + const ::mindspore::irpb::Explain_Hoc& hoc(int index) const; + ::mindspore::irpb::Explain_Hoc* add_hoc(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Hoc >& + hoc() const; + + // optional string image_path = 2; + bool has_image_path() const; + private: + bool _internal_has_image_path() const; + public: + void clear_image_path(); + const std::string& image_path() const; + void set_image_path(const std::string& value); + void set_image_path(std::string&& value); + void set_image_path(const char* value); + void set_image_path(const char* value, size_t size); + std::string* mutable_image_path(); + std::string* release_image_path(); + void set_allocated_image_path(std::string* image_path); + private: + const std::string& _internal_image_path() const; + void _internal_set_image_path(const std::string& value); + std::string* _internal_mutable_image_path(); + public: + + // optional string status = 8; + bool has_status() const; + private: + bool _internal_has_status() const; + public: + void clear_status(); + const std::string& status() const; + void set_status(const std::string& value); + void set_status(std::string&& value); + void set_status(const char* value); + void set_status(const char* value, size_t size); + std::string* mutable_status(); + std::string* release_status(); + void set_allocated_status(std::string* status); + private: + const std::string& _internal_status() const; + void _internal_set_status(const std::string& value); + std::string* _internal_mutable_status(); + public: + + // optional .mindspore.irpb.Explain.Inference inference = 4; + bool has_inference() const; + private: + bool _internal_has_inference() const; + public: + void clear_inference(); + const ::mindspore::irpb::Explain_Inference& inference() const; + ::mindspore::irpb::Explain_Inference* release_inference(); + ::mindspore::irpb::Explain_Inference* mutable_inference(); + void set_allocated_inference(::mindspore::irpb::Explain_Inference* inference); + private: + const ::mindspore::irpb::Explain_Inference& _internal_inference() const; + ::mindspore::irpb::Explain_Inference* _internal_mutable_inference(); + public: + void unsafe_arena_set_allocated_inference( + ::mindspore::irpb::Explain_Inference* inference); + ::mindspore::irpb::Explain_Inference* unsafe_arena_release_inference(); + + // optional .mindspore.irpb.Explain.Metadata metadata = 7; + bool has_metadata() const; + private: + bool _internal_has_metadata() const; + public: + void clear_metadata(); + const ::mindspore::irpb::Explain_Metadata& metadata() const; + ::mindspore::irpb::Explain_Metadata* release_metadata(); + ::mindspore::irpb::Explain_Metadata* mutable_metadata(); + void set_allocated_metadata(::mindspore::irpb::Explain_Metadata* metadata); + private: + const ::mindspore::irpb::Explain_Metadata& _internal_metadata() const; + ::mindspore::irpb::Explain_Metadata* _internal_mutable_metadata(); + public: + void unsafe_arena_set_allocated_metadata( + ::mindspore::irpb::Explain_Metadata* metadata); + ::mindspore::irpb::Explain_Metadata* unsafe_arena_release_metadata(); + + // optional int32 sample_id = 1; + bool has_sample_id() const; + private: + bool _internal_has_sample_id() const; + public: + void clear_sample_id(); + ::PROTOBUF_NAMESPACE_ID::int32 sample_id() const; + void set_sample_id(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_sample_id() const; + void _internal_set_sample_id(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:mindspore.irpb.Explain) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > ground_truth_label_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Explanation > explanation_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Benchmark > benchmark_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Hoc > hoc_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr image_path_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr status_; + ::mindspore::irpb::Explain_Inference* inference_; + ::mindspore::irpb::Explain_Metadata* metadata_; + ::PROTOBUF_NAMESPACE_ID::int32 sample_id_; + friend struct ::TableStruct_mindspore_5fsummary_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// Event + +// required double wall_time = 1; +inline bool Event::_internal_has_wall_time() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Event::has_wall_time() const { + return _internal_has_wall_time(); +} +inline void Event::clear_wall_time() { + wall_time_ = 0; + _has_bits_[0] &= ~0x00000001u; +} +inline double Event::_internal_wall_time() const { + return wall_time_; +} +inline double Event::wall_time() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Event.wall_time) + return _internal_wall_time(); +} +inline void Event::_internal_set_wall_time(double value) { + _has_bits_[0] |= 0x00000001u; + wall_time_ = value; +} +inline void Event::set_wall_time(double value) { + _internal_set_wall_time(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Event.wall_time) +} + +// optional int64 step = 2; +inline bool Event::_internal_has_step() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool Event::has_step() const { + return _internal_has_step(); +} +inline void Event::clear_step() { + step_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Event::_internal_step() const { + return step_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Event::step() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Event.step) + return _internal_step(); +} +inline void Event::_internal_set_step(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000002u; + step_ = value; +} +inline void Event::set_step(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_step(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Event.step) +} + +// string version = 3; +inline bool Event::_internal_has_version() const { + return what_case() == kVersion; +} +inline bool Event::has_version() const { + return _internal_has_version(); +} +inline void Event::set_has_version() { + _oneof_case_[0] = kVersion; +} +inline void Event::clear_version() { + if (_internal_has_version()) { + what_.version_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + clear_has_what(); + } +} +inline const std::string& Event::version() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Event.version) + return _internal_version(); +} +inline void Event::set_version(const std::string& value) { + _internal_set_version(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Event.version) +} +inline std::string* Event::mutable_version() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Event.version) + return _internal_mutable_version(); +} +inline const std::string& Event::_internal_version() const { + if (_internal_has_version()) { + return what_.version_.Get(); + } + return *&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(); +} +inline void Event::_internal_set_version(const std::string& value) { + if (!_internal_has_version()) { + clear_what(); + set_has_version(); + what_.version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Event::set_version(std::string&& value) { + // @@protoc_insertion_point(field_set:mindspore.irpb.Event.version) + if (!_internal_has_version()) { + clear_what(); + set_has_version(); + what_.version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.version_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.Event.version) +} +inline void Event::set_version(const char* value) { + GOOGLE_DCHECK(value != nullptr); + if (!_internal_has_version()) { + clear_what(); + set_has_version(); + what_.version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(value), GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Event.version) +} +inline void Event::set_version(const char* value, + size_t size) { + if (!_internal_has_version()) { + clear_what(); + set_has_version(); + what_.version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + what_.version_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), + GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Event.version) +} +inline std::string* Event::_internal_mutable_version() { + if (!_internal_has_version()) { + clear_what(); + set_has_version(); + what_.version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + return what_.version_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Event::release_version() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Event.version) + if (_internal_has_version()) { + clear_has_what(); + return what_.version_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + } else { + return nullptr; + } +} +inline void Event::set_allocated_version(std::string* version) { + if (has_what()) { + clear_what(); + } + if (version != nullptr) { + set_has_version(); + what_.version_.UnsafeSetDefault(version); + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); + if (arena != nullptr) { + arena->Own(version); + } + } + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Event.version) +} + +// .mindspore.irpb.GraphProto graph_def = 4; +inline bool Event::_internal_has_graph_def() const { + return what_case() == kGraphDef; +} +inline bool Event::has_graph_def() const { + return _internal_has_graph_def(); +} +inline void Event::set_has_graph_def() { + _oneof_case_[0] = kGraphDef; +} +inline ::mindspore::irpb::GraphProto* Event::release_graph_def() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Event.graph_def) + if (_internal_has_graph_def()) { + clear_has_what(); + ::mindspore::irpb::GraphProto* temp = what_.graph_def_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + what_.graph_def_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::mindspore::irpb::GraphProto& Event::_internal_graph_def() const { + return _internal_has_graph_def() + ? *what_.graph_def_ + : *reinterpret_cast< ::mindspore::irpb::GraphProto*>(&::mindspore::irpb::_GraphProto_default_instance_); +} +inline const ::mindspore::irpb::GraphProto& Event::graph_def() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Event.graph_def) + return _internal_graph_def(); +} +inline ::mindspore::irpb::GraphProto* Event::unsafe_arena_release_graph_def() { + // @@protoc_insertion_point(field_unsafe_arena_release:mindspore.irpb.Event.graph_def) + if (_internal_has_graph_def()) { + clear_has_what(); + ::mindspore::irpb::GraphProto* temp = what_.graph_def_; + what_.graph_def_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Event::unsafe_arena_set_allocated_graph_def(::mindspore::irpb::GraphProto* graph_def) { + clear_what(); + if (graph_def) { + set_has_graph_def(); + what_.graph_def_ = graph_def; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.Event.graph_def) +} +inline ::mindspore::irpb::GraphProto* Event::_internal_mutable_graph_def() { + if (!_internal_has_graph_def()) { + clear_what(); + set_has_graph_def(); + what_.graph_def_ = CreateMaybeMessage< ::mindspore::irpb::GraphProto >(GetArena()); + } + return what_.graph_def_; +} +inline ::mindspore::irpb::GraphProto* Event::mutable_graph_def() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Event.graph_def) + return _internal_mutable_graph_def(); +} + +// .mindspore.irpb.Summary summary = 5; +inline bool Event::_internal_has_summary() const { + return what_case() == kSummary; +} +inline bool Event::has_summary() const { + return _internal_has_summary(); +} +inline void Event::set_has_summary() { + _oneof_case_[0] = kSummary; +} +inline void Event::clear_summary() { + if (_internal_has_summary()) { + if (GetArena() == nullptr) { + delete what_.summary_; + } + clear_has_what(); + } +} +inline ::mindspore::irpb::Summary* Event::release_summary() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Event.summary) + if (_internal_has_summary()) { + clear_has_what(); + ::mindspore::irpb::Summary* temp = what_.summary_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + what_.summary_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::mindspore::irpb::Summary& Event::_internal_summary() const { + return _internal_has_summary() + ? *what_.summary_ + : *reinterpret_cast< ::mindspore::irpb::Summary*>(&::mindspore::irpb::_Summary_default_instance_); +} +inline const ::mindspore::irpb::Summary& Event::summary() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Event.summary) + return _internal_summary(); +} +inline ::mindspore::irpb::Summary* Event::unsafe_arena_release_summary() { + // @@protoc_insertion_point(field_unsafe_arena_release:mindspore.irpb.Event.summary) + if (_internal_has_summary()) { + clear_has_what(); + ::mindspore::irpb::Summary* temp = what_.summary_; + what_.summary_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Event::unsafe_arena_set_allocated_summary(::mindspore::irpb::Summary* summary) { + clear_what(); + if (summary) { + set_has_summary(); + what_.summary_ = summary; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.Event.summary) +} +inline ::mindspore::irpb::Summary* Event::_internal_mutable_summary() { + if (!_internal_has_summary()) { + clear_what(); + set_has_summary(); + what_.summary_ = CreateMaybeMessage< ::mindspore::irpb::Summary >(GetArena()); + } + return what_.summary_; +} +inline ::mindspore::irpb::Summary* Event::mutable_summary() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Event.summary) + return _internal_mutable_summary(); +} + +// .mindspore.irpb.Explain explain = 6; +inline bool Event::_internal_has_explain() const { + return what_case() == kExplain; +} +inline bool Event::has_explain() const { + return _internal_has_explain(); +} +inline void Event::set_has_explain() { + _oneof_case_[0] = kExplain; +} +inline void Event::clear_explain() { + if (_internal_has_explain()) { + if (GetArena() == nullptr) { + delete what_.explain_; + } + clear_has_what(); + } +} +inline ::mindspore::irpb::Explain* Event::release_explain() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Event.explain) + if (_internal_has_explain()) { + clear_has_what(); + ::mindspore::irpb::Explain* temp = what_.explain_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + what_.explain_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::mindspore::irpb::Explain& Event::_internal_explain() const { + return _internal_has_explain() + ? *what_.explain_ + : *reinterpret_cast< ::mindspore::irpb::Explain*>(&::mindspore::irpb::_Explain_default_instance_); +} +inline const ::mindspore::irpb::Explain& Event::explain() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Event.explain) + return _internal_explain(); +} +inline ::mindspore::irpb::Explain* Event::unsafe_arena_release_explain() { + // @@protoc_insertion_point(field_unsafe_arena_release:mindspore.irpb.Event.explain) + if (_internal_has_explain()) { + clear_has_what(); + ::mindspore::irpb::Explain* temp = what_.explain_; + what_.explain_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Event::unsafe_arena_set_allocated_explain(::mindspore::irpb::Explain* explain) { + clear_what(); + if (explain) { + set_has_explain(); + what_.explain_ = explain; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.Event.explain) +} +inline ::mindspore::irpb::Explain* Event::_internal_mutable_explain() { + if (!_internal_has_explain()) { + clear_what(); + set_has_explain(); + what_.explain_ = CreateMaybeMessage< ::mindspore::irpb::Explain >(GetArena()); + } + return what_.explain_; +} +inline ::mindspore::irpb::Explain* Event::mutable_explain() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Event.explain) + return _internal_mutable_explain(); +} + +inline bool Event::has_what() const { + return what_case() != WHAT_NOT_SET; +} +inline void Event::clear_has_what() { + _oneof_case_[0] = WHAT_NOT_SET; +} +inline Event::WhatCase Event::what_case() const { + return Event::WhatCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + +// LossLandscape_Point + +// optional .mindspore.irpb.TensorProto x = 1; +inline bool LossLandscape_Point::_internal_has_x() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || x_ != nullptr); + return value; +} +inline bool LossLandscape_Point::has_x() const { + return _internal_has_x(); +} +inline const ::mindspore::irpb::TensorProto& LossLandscape_Point::_internal_x() const { + const ::mindspore::irpb::TensorProto* p = x_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_TensorProto_default_instance_); +} +inline const ::mindspore::irpb::TensorProto& LossLandscape_Point::x() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.Point.x) + return _internal_x(); +} +inline void LossLandscape_Point::unsafe_arena_set_allocated_x( + ::mindspore::irpb::TensorProto* x) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(x_); + } + x_ = x; + if (x) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.LossLandscape.Point.x) +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::release_x() { + _has_bits_[0] &= ~0x00000001u; + ::mindspore::irpb::TensorProto* temp = x_; + x_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::unsafe_arena_release_x() { + // @@protoc_insertion_point(field_release:mindspore.irpb.LossLandscape.Point.x) + _has_bits_[0] &= ~0x00000001u; + ::mindspore::irpb::TensorProto* temp = x_; + x_ = nullptr; + return temp; +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::_internal_mutable_x() { + _has_bits_[0] |= 0x00000001u; + if (x_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::TensorProto>(GetArena()); + x_ = p; + } + return x_; +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::mutable_x() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.LossLandscape.Point.x) + return _internal_mutable_x(); +} +inline void LossLandscape_Point::set_allocated_x(::mindspore::irpb::TensorProto* x) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete reinterpret_cast< ::PROTOBUF_NAMESPACE_ID::MessageLite*>(x_); + } + if (x) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(x)->GetArena(); + if (message_arena != submessage_arena) { + x = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, x, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + x_ = x; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.LossLandscape.Point.x) +} + +// optional .mindspore.irpb.TensorProto y = 2; +inline bool LossLandscape_Point::_internal_has_y() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || y_ != nullptr); + return value; +} +inline bool LossLandscape_Point::has_y() const { + return _internal_has_y(); +} +inline const ::mindspore::irpb::TensorProto& LossLandscape_Point::_internal_y() const { + const ::mindspore::irpb::TensorProto* p = y_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_TensorProto_default_instance_); +} +inline const ::mindspore::irpb::TensorProto& LossLandscape_Point::y() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.Point.y) + return _internal_y(); +} +inline void LossLandscape_Point::unsafe_arena_set_allocated_y( + ::mindspore::irpb::TensorProto* y) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(y_); + } + y_ = y; + if (y) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.LossLandscape.Point.y) +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::release_y() { + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::TensorProto* temp = y_; + y_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::unsafe_arena_release_y() { + // @@protoc_insertion_point(field_release:mindspore.irpb.LossLandscape.Point.y) + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::TensorProto* temp = y_; + y_ = nullptr; + return temp; +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::_internal_mutable_y() { + _has_bits_[0] |= 0x00000002u; + if (y_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::TensorProto>(GetArena()); + y_ = p; + } + return y_; +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::mutable_y() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.LossLandscape.Point.y) + return _internal_mutable_y(); +} +inline void LossLandscape_Point::set_allocated_y(::mindspore::irpb::TensorProto* y) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete reinterpret_cast< ::PROTOBUF_NAMESPACE_ID::MessageLite*>(y_); + } + if (y) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(y)->GetArena(); + if (message_arena != submessage_arena) { + y = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, y, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + y_ = y; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.LossLandscape.Point.y) +} + +// optional .mindspore.irpb.TensorProto z = 3; +inline bool LossLandscape_Point::_internal_has_z() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + PROTOBUF_ASSUME(!value || z_ != nullptr); + return value; +} +inline bool LossLandscape_Point::has_z() const { + return _internal_has_z(); +} +inline const ::mindspore::irpb::TensorProto& LossLandscape_Point::_internal_z() const { + const ::mindspore::irpb::TensorProto* p = z_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_TensorProto_default_instance_); +} +inline const ::mindspore::irpb::TensorProto& LossLandscape_Point::z() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.Point.z) + return _internal_z(); +} +inline void LossLandscape_Point::unsafe_arena_set_allocated_z( + ::mindspore::irpb::TensorProto* z) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(z_); + } + z_ = z; + if (z) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.LossLandscape.Point.z) +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::release_z() { + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::TensorProto* temp = z_; + z_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::unsafe_arena_release_z() { + // @@protoc_insertion_point(field_release:mindspore.irpb.LossLandscape.Point.z) + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::TensorProto* temp = z_; + z_ = nullptr; + return temp; +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::_internal_mutable_z() { + _has_bits_[0] |= 0x00000004u; + if (z_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::TensorProto>(GetArena()); + z_ = p; + } + return z_; +} +inline ::mindspore::irpb::TensorProto* LossLandscape_Point::mutable_z() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.LossLandscape.Point.z) + return _internal_mutable_z(); +} +inline void LossLandscape_Point::set_allocated_z(::mindspore::irpb::TensorProto* z) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete reinterpret_cast< ::PROTOBUF_NAMESPACE_ID::MessageLite*>(z_); + } + if (z) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(z)->GetArena(); + if (message_arena != submessage_arena) { + z = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, z, submessage_arena); + } + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + z_ = z; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.LossLandscape.Point.z) +} + +// ------------------------------------------------------------------- + +// LossLandscape_LossPath + +// repeated int32 intervals = 1; +inline int LossLandscape_LossPath::_internal_intervals_size() const { + return intervals_.size(); +} +inline int LossLandscape_LossPath::intervals_size() const { + return _internal_intervals_size(); +} +inline void LossLandscape_LossPath::clear_intervals() { + intervals_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 LossLandscape_LossPath::_internal_intervals(int index) const { + return intervals_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 LossLandscape_LossPath::intervals(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.LossPath.intervals) + return _internal_intervals(index); +} +inline void LossLandscape_LossPath::set_intervals(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + intervals_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.LossLandscape.LossPath.intervals) +} +inline void LossLandscape_LossPath::_internal_add_intervals(::PROTOBUF_NAMESPACE_ID::int32 value) { + intervals_.Add(value); +} +inline void LossLandscape_LossPath::add_intervals(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_intervals(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.LossLandscape.LossPath.intervals) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +LossLandscape_LossPath::_internal_intervals() const { + return intervals_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +LossLandscape_LossPath::intervals() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.LossLandscape.LossPath.intervals) + return _internal_intervals(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +LossLandscape_LossPath::_internal_mutable_intervals() { + return &intervals_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +LossLandscape_LossPath::mutable_intervals() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.LossLandscape.LossPath.intervals) + return _internal_mutable_intervals(); +} + +// optional .mindspore.irpb.LossLandscape.Point points = 2; +inline bool LossLandscape_LossPath::_internal_has_points() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || points_ != nullptr); + return value; +} +inline bool LossLandscape_LossPath::has_points() const { + return _internal_has_points(); +} +inline void LossLandscape_LossPath::clear_points() { + if (points_ != nullptr) points_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::mindspore::irpb::LossLandscape_Point& LossLandscape_LossPath::_internal_points() const { + const ::mindspore::irpb::LossLandscape_Point* p = points_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_LossLandscape_Point_default_instance_); +} +inline const ::mindspore::irpb::LossLandscape_Point& LossLandscape_LossPath::points() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.LossPath.points) + return _internal_points(); +} +inline void LossLandscape_LossPath::unsafe_arena_set_allocated_points( + ::mindspore::irpb::LossLandscape_Point* points) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(points_); + } + points_ = points; + if (points) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.LossLandscape.LossPath.points) +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape_LossPath::release_points() { + _has_bits_[0] &= ~0x00000001u; + ::mindspore::irpb::LossLandscape_Point* temp = points_; + points_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape_LossPath::unsafe_arena_release_points() { + // @@protoc_insertion_point(field_release:mindspore.irpb.LossLandscape.LossPath.points) + _has_bits_[0] &= ~0x00000001u; + ::mindspore::irpb::LossLandscape_Point* temp = points_; + points_ = nullptr; + return temp; +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape_LossPath::_internal_mutable_points() { + _has_bits_[0] |= 0x00000001u; + if (points_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::LossLandscape_Point>(GetArena()); + points_ = p; + } + return points_; +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape_LossPath::mutable_points() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.LossLandscape.LossPath.points) + return _internal_mutable_points(); +} +inline void LossLandscape_LossPath::set_allocated_points(::mindspore::irpb::LossLandscape_Point* points) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete points_; + } + if (points) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(points); + if (message_arena != submessage_arena) { + points = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, points, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + points_ = points; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.LossLandscape.LossPath.points) +} + +// ------------------------------------------------------------------- + +// LossLandscape_Metadata + +// optional string decomposition = 1; +inline bool LossLandscape_Metadata::_internal_has_decomposition() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool LossLandscape_Metadata::has_decomposition() const { + return _internal_has_decomposition(); +} +inline void LossLandscape_Metadata::clear_decomposition() { + decomposition_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& LossLandscape_Metadata::decomposition() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.Metadata.decomposition) + return _internal_decomposition(); +} +inline void LossLandscape_Metadata::set_decomposition(const std::string& value) { + _internal_set_decomposition(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.LossLandscape.Metadata.decomposition) +} +inline std::string* LossLandscape_Metadata::mutable_decomposition() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.LossLandscape.Metadata.decomposition) + return _internal_mutable_decomposition(); +} +inline const std::string& LossLandscape_Metadata::_internal_decomposition() const { + return decomposition_.Get(); +} +inline void LossLandscape_Metadata::_internal_set_decomposition(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + decomposition_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void LossLandscape_Metadata::set_decomposition(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + decomposition_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.LossLandscape.Metadata.decomposition) +} +inline void LossLandscape_Metadata::set_decomposition(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + decomposition_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.LossLandscape.Metadata.decomposition) +} +inline void LossLandscape_Metadata::set_decomposition(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + decomposition_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.LossLandscape.Metadata.decomposition) +} +inline std::string* LossLandscape_Metadata::_internal_mutable_decomposition() { + _has_bits_[0] |= 0x00000001u; + return decomposition_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* LossLandscape_Metadata::release_decomposition() { + // @@protoc_insertion_point(field_release:mindspore.irpb.LossLandscape.Metadata.decomposition) + if (!_internal_has_decomposition()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return decomposition_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void LossLandscape_Metadata::set_allocated_decomposition(std::string* decomposition) { + if (decomposition != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + decomposition_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), decomposition, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.LossLandscape.Metadata.decomposition) +} + +// optional string unit = 2; +inline bool LossLandscape_Metadata::_internal_has_unit() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool LossLandscape_Metadata::has_unit() const { + return _internal_has_unit(); +} +inline void LossLandscape_Metadata::clear_unit() { + unit_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& LossLandscape_Metadata::unit() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.Metadata.unit) + return _internal_unit(); +} +inline void LossLandscape_Metadata::set_unit(const std::string& value) { + _internal_set_unit(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.LossLandscape.Metadata.unit) +} +inline std::string* LossLandscape_Metadata::mutable_unit() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.LossLandscape.Metadata.unit) + return _internal_mutable_unit(); +} +inline const std::string& LossLandscape_Metadata::_internal_unit() const { + return unit_.Get(); +} +inline void LossLandscape_Metadata::_internal_set_unit(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + unit_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void LossLandscape_Metadata::set_unit(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + unit_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.LossLandscape.Metadata.unit) +} +inline void LossLandscape_Metadata::set_unit(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + unit_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.LossLandscape.Metadata.unit) +} +inline void LossLandscape_Metadata::set_unit(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + unit_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.LossLandscape.Metadata.unit) +} +inline std::string* LossLandscape_Metadata::_internal_mutable_unit() { + _has_bits_[0] |= 0x00000002u; + return unit_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* LossLandscape_Metadata::release_unit() { + // @@protoc_insertion_point(field_release:mindspore.irpb.LossLandscape.Metadata.unit) + if (!_internal_has_unit()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return unit_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void LossLandscape_Metadata::set_allocated_unit(std::string* unit) { + if (unit != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + unit_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), unit, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.LossLandscape.Metadata.unit) +} + +// optional int32 step_per_epoch = 3; +inline bool LossLandscape_Metadata::_internal_has_step_per_epoch() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool LossLandscape_Metadata::has_step_per_epoch() const { + return _internal_has_step_per_epoch(); +} +inline void LossLandscape_Metadata::clear_step_per_epoch() { + step_per_epoch_ = 0; + _has_bits_[0] &= ~0x00000004u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 LossLandscape_Metadata::_internal_step_per_epoch() const { + return step_per_epoch_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 LossLandscape_Metadata::step_per_epoch() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.Metadata.step_per_epoch) + return _internal_step_per_epoch(); +} +inline void LossLandscape_Metadata::_internal_set_step_per_epoch(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000004u; + step_per_epoch_ = value; +} +inline void LossLandscape_Metadata::set_step_per_epoch(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_step_per_epoch(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.LossLandscape.Metadata.step_per_epoch) +} + +// ------------------------------------------------------------------- + +// LossLandscape + +// optional .mindspore.irpb.LossLandscape.Point landscape = 1; +inline bool LossLandscape::_internal_has_landscape() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || landscape_ != nullptr); + return value; +} +inline bool LossLandscape::has_landscape() const { + return _internal_has_landscape(); +} +inline void LossLandscape::clear_landscape() { + if (landscape_ != nullptr) landscape_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::mindspore::irpb::LossLandscape_Point& LossLandscape::_internal_landscape() const { + const ::mindspore::irpb::LossLandscape_Point* p = landscape_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_LossLandscape_Point_default_instance_); +} +inline const ::mindspore::irpb::LossLandscape_Point& LossLandscape::landscape() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.landscape) + return _internal_landscape(); +} +inline void LossLandscape::unsafe_arena_set_allocated_landscape( + ::mindspore::irpb::LossLandscape_Point* landscape) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(landscape_); + } + landscape_ = landscape; + if (landscape) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.LossLandscape.landscape) +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape::release_landscape() { + _has_bits_[0] &= ~0x00000001u; + ::mindspore::irpb::LossLandscape_Point* temp = landscape_; + landscape_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape::unsafe_arena_release_landscape() { + // @@protoc_insertion_point(field_release:mindspore.irpb.LossLandscape.landscape) + _has_bits_[0] &= ~0x00000001u; + ::mindspore::irpb::LossLandscape_Point* temp = landscape_; + landscape_ = nullptr; + return temp; +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape::_internal_mutable_landscape() { + _has_bits_[0] |= 0x00000001u; + if (landscape_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::LossLandscape_Point>(GetArena()); + landscape_ = p; + } + return landscape_; +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape::mutable_landscape() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.LossLandscape.landscape) + return _internal_mutable_landscape(); +} +inline void LossLandscape::set_allocated_landscape(::mindspore::irpb::LossLandscape_Point* landscape) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete landscape_; + } + if (landscape) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(landscape); + if (message_arena != submessage_arena) { + landscape = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, landscape, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + landscape_ = landscape; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.LossLandscape.landscape) +} + +// optional .mindspore.irpb.LossLandscape.LossPath loss_path = 2; +inline bool LossLandscape::_internal_has_loss_path() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + PROTOBUF_ASSUME(!value || loss_path_ != nullptr); + return value; +} +inline bool LossLandscape::has_loss_path() const { + return _internal_has_loss_path(); +} +inline void LossLandscape::clear_loss_path() { + if (loss_path_ != nullptr) loss_path_->Clear(); + _has_bits_[0] &= ~0x00000002u; +} +inline const ::mindspore::irpb::LossLandscape_LossPath& LossLandscape::_internal_loss_path() const { + const ::mindspore::irpb::LossLandscape_LossPath* p = loss_path_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_LossLandscape_LossPath_default_instance_); +} +inline const ::mindspore::irpb::LossLandscape_LossPath& LossLandscape::loss_path() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.loss_path) + return _internal_loss_path(); +} +inline void LossLandscape::unsafe_arena_set_allocated_loss_path( + ::mindspore::irpb::LossLandscape_LossPath* loss_path) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(loss_path_); + } + loss_path_ = loss_path; + if (loss_path) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.LossLandscape.loss_path) +} +inline ::mindspore::irpb::LossLandscape_LossPath* LossLandscape::release_loss_path() { + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::LossLandscape_LossPath* temp = loss_path_; + loss_path_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::LossLandscape_LossPath* LossLandscape::unsafe_arena_release_loss_path() { + // @@protoc_insertion_point(field_release:mindspore.irpb.LossLandscape.loss_path) + _has_bits_[0] &= ~0x00000002u; + ::mindspore::irpb::LossLandscape_LossPath* temp = loss_path_; + loss_path_ = nullptr; + return temp; +} +inline ::mindspore::irpb::LossLandscape_LossPath* LossLandscape::_internal_mutable_loss_path() { + _has_bits_[0] |= 0x00000002u; + if (loss_path_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::LossLandscape_LossPath>(GetArena()); + loss_path_ = p; + } + return loss_path_; +} +inline ::mindspore::irpb::LossLandscape_LossPath* LossLandscape::mutable_loss_path() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.LossLandscape.loss_path) + return _internal_mutable_loss_path(); +} +inline void LossLandscape::set_allocated_loss_path(::mindspore::irpb::LossLandscape_LossPath* loss_path) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete loss_path_; + } + if (loss_path) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(loss_path); + if (message_arena != submessage_arena) { + loss_path = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, loss_path, submessage_arena); + } + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + loss_path_ = loss_path; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.LossLandscape.loss_path) +} + +// optional .mindspore.irpb.LossLandscape.Metadata metadata = 3; +inline bool LossLandscape::_internal_has_metadata() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + PROTOBUF_ASSUME(!value || metadata_ != nullptr); + return value; +} +inline bool LossLandscape::has_metadata() const { + return _internal_has_metadata(); +} +inline void LossLandscape::clear_metadata() { + if (metadata_ != nullptr) metadata_->Clear(); + _has_bits_[0] &= ~0x00000004u; +} +inline const ::mindspore::irpb::LossLandscape_Metadata& LossLandscape::_internal_metadata() const { + const ::mindspore::irpb::LossLandscape_Metadata* p = metadata_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_LossLandscape_Metadata_default_instance_); +} +inline const ::mindspore::irpb::LossLandscape_Metadata& LossLandscape::metadata() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.metadata) + return _internal_metadata(); +} +inline void LossLandscape::unsafe_arena_set_allocated_metadata( + ::mindspore::irpb::LossLandscape_Metadata* metadata) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(metadata_); + } + metadata_ = metadata; + if (metadata) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.LossLandscape.metadata) +} +inline ::mindspore::irpb::LossLandscape_Metadata* LossLandscape::release_metadata() { + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::LossLandscape_Metadata* temp = metadata_; + metadata_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::LossLandscape_Metadata* LossLandscape::unsafe_arena_release_metadata() { + // @@protoc_insertion_point(field_release:mindspore.irpb.LossLandscape.metadata) + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::LossLandscape_Metadata* temp = metadata_; + metadata_ = nullptr; + return temp; +} +inline ::mindspore::irpb::LossLandscape_Metadata* LossLandscape::_internal_mutable_metadata() { + _has_bits_[0] |= 0x00000004u; + if (metadata_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::LossLandscape_Metadata>(GetArena()); + metadata_ = p; + } + return metadata_; +} +inline ::mindspore::irpb::LossLandscape_Metadata* LossLandscape::mutable_metadata() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.LossLandscape.metadata) + return _internal_mutable_metadata(); +} +inline void LossLandscape::set_allocated_metadata(::mindspore::irpb::LossLandscape_Metadata* metadata) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete metadata_; + } + if (metadata) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(metadata); + if (message_arena != submessage_arena) { + metadata = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, metadata, submessage_arena); + } + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + metadata_ = metadata; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.LossLandscape.metadata) +} + +// optional .mindspore.irpb.LossLandscape.Point convergence_point = 4; +inline bool LossLandscape::_internal_has_convergence_point() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + PROTOBUF_ASSUME(!value || convergence_point_ != nullptr); + return value; +} +inline bool LossLandscape::has_convergence_point() const { + return _internal_has_convergence_point(); +} +inline void LossLandscape::clear_convergence_point() { + if (convergence_point_ != nullptr) convergence_point_->Clear(); + _has_bits_[0] &= ~0x00000008u; +} +inline const ::mindspore::irpb::LossLandscape_Point& LossLandscape::_internal_convergence_point() const { + const ::mindspore::irpb::LossLandscape_Point* p = convergence_point_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_LossLandscape_Point_default_instance_); +} +inline const ::mindspore::irpb::LossLandscape_Point& LossLandscape::convergence_point() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.LossLandscape.convergence_point) + return _internal_convergence_point(); +} +inline void LossLandscape::unsafe_arena_set_allocated_convergence_point( + ::mindspore::irpb::LossLandscape_Point* convergence_point) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(convergence_point_); + } + convergence_point_ = convergence_point; + if (convergence_point) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.LossLandscape.convergence_point) +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape::release_convergence_point() { + _has_bits_[0] &= ~0x00000008u; + ::mindspore::irpb::LossLandscape_Point* temp = convergence_point_; + convergence_point_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape::unsafe_arena_release_convergence_point() { + // @@protoc_insertion_point(field_release:mindspore.irpb.LossLandscape.convergence_point) + _has_bits_[0] &= ~0x00000008u; + ::mindspore::irpb::LossLandscape_Point* temp = convergence_point_; + convergence_point_ = nullptr; + return temp; +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape::_internal_mutable_convergence_point() { + _has_bits_[0] |= 0x00000008u; + if (convergence_point_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::LossLandscape_Point>(GetArena()); + convergence_point_ = p; + } + return convergence_point_; +} +inline ::mindspore::irpb::LossLandscape_Point* LossLandscape::mutable_convergence_point() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.LossLandscape.convergence_point) + return _internal_mutable_convergence_point(); +} +inline void LossLandscape::set_allocated_convergence_point(::mindspore::irpb::LossLandscape_Point* convergence_point) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete convergence_point_; + } + if (convergence_point) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(convergence_point); + if (message_arena != submessage_arena) { + convergence_point = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, convergence_point, submessage_arena); + } + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + convergence_point_ = convergence_point; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.LossLandscape.convergence_point) +} + +// ------------------------------------------------------------------- + +// Summary_Image + +// required int32 height = 1; +inline bool Summary_Image::_internal_has_height() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool Summary_Image::has_height() const { + return _internal_has_height(); +} +inline void Summary_Image::clear_height() { + height_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::_internal_height() const { + return height_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::height() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Image.height) + return _internal_height(); +} +inline void Summary_Image::_internal_set_height(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000002u; + height_ = value; +} +inline void Summary_Image::set_height(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_height(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Image.height) +} + +// required int32 width = 2; +inline bool Summary_Image::_internal_has_width() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool Summary_Image::has_width() const { + return _internal_has_width(); +} +inline void Summary_Image::clear_width() { + width_ = 0; + _has_bits_[0] &= ~0x00000004u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::_internal_width() const { + return width_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::width() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Image.width) + return _internal_width(); +} +inline void Summary_Image::_internal_set_width(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000004u; + width_ = value; +} +inline void Summary_Image::set_width(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_width(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Image.width) +} + +// required int32 colorspace = 3; +inline bool Summary_Image::_internal_has_colorspace() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool Summary_Image::has_colorspace() const { + return _internal_has_colorspace(); +} +inline void Summary_Image::clear_colorspace() { + colorspace_ = 0; + _has_bits_[0] &= ~0x00000008u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::_internal_colorspace() const { + return colorspace_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::colorspace() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Image.colorspace) + return _internal_colorspace(); +} +inline void Summary_Image::_internal_set_colorspace(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000008u; + colorspace_ = value; +} +inline void Summary_Image::set_colorspace(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_colorspace(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Image.colorspace) +} + +// required bytes encoded_image = 4; +inline bool Summary_Image::_internal_has_encoded_image() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Summary_Image::has_encoded_image() const { + return _internal_has_encoded_image(); +} +inline void Summary_Image::clear_encoded_image() { + encoded_image_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& Summary_Image::encoded_image() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Image.encoded_image) + return _internal_encoded_image(); +} +inline void Summary_Image::set_encoded_image(const std::string& value) { + _internal_set_encoded_image(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Image.encoded_image) +} +inline std::string* Summary_Image::mutable_encoded_image() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Summary.Image.encoded_image) + return _internal_mutable_encoded_image(); +} +inline const std::string& Summary_Image::_internal_encoded_image() const { + return encoded_image_.Get(); +} +inline void Summary_Image::_internal_set_encoded_image(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + encoded_image_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Summary_Image::set_encoded_image(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + encoded_image_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.Summary.Image.encoded_image) +} +inline void Summary_Image::set_encoded_image(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + encoded_image_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Summary.Image.encoded_image) +} +inline void Summary_Image::set_encoded_image(const void* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + encoded_image_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Summary.Image.encoded_image) +} +inline std::string* Summary_Image::_internal_mutable_encoded_image() { + _has_bits_[0] |= 0x00000001u; + return encoded_image_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Summary_Image::release_encoded_image() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Summary.Image.encoded_image) + if (!_internal_has_encoded_image()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return encoded_image_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Summary_Image::set_allocated_encoded_image(std::string* encoded_image) { + if (encoded_image != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + encoded_image_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), encoded_image, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Summary.Image.encoded_image) +} + +// ------------------------------------------------------------------- + +// Summary_Histogram_bucket + +// required double left = 1; +inline bool Summary_Histogram_bucket::_internal_has_left() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Summary_Histogram_bucket::has_left() const { + return _internal_has_left(); +} +inline void Summary_Histogram_bucket::clear_left() { + left_ = 0; + _has_bits_[0] &= ~0x00000001u; +} +inline double Summary_Histogram_bucket::_internal_left() const { + return left_; +} +inline double Summary_Histogram_bucket::left() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.bucket.left) + return _internal_left(); +} +inline void Summary_Histogram_bucket::_internal_set_left(double value) { + _has_bits_[0] |= 0x00000001u; + left_ = value; +} +inline void Summary_Histogram_bucket::set_left(double value) { + _internal_set_left(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Histogram.bucket.left) +} + +// required double width = 2; +inline bool Summary_Histogram_bucket::_internal_has_width() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool Summary_Histogram_bucket::has_width() const { + return _internal_has_width(); +} +inline void Summary_Histogram_bucket::clear_width() { + width_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline double Summary_Histogram_bucket::_internal_width() const { + return width_; +} +inline double Summary_Histogram_bucket::width() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.bucket.width) + return _internal_width(); +} +inline void Summary_Histogram_bucket::_internal_set_width(double value) { + _has_bits_[0] |= 0x00000002u; + width_ = value; +} +inline void Summary_Histogram_bucket::set_width(double value) { + _internal_set_width(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Histogram.bucket.width) +} + +// required int64 count = 3; +inline bool Summary_Histogram_bucket::_internal_has_count() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool Summary_Histogram_bucket::has_count() const { + return _internal_has_count(); +} +inline void Summary_Histogram_bucket::clear_count() { + count_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000004u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Histogram_bucket::_internal_count() const { + return count_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Histogram_bucket::count() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.bucket.count) + return _internal_count(); +} +inline void Summary_Histogram_bucket::_internal_set_count(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000004u; + count_ = value; +} +inline void Summary_Histogram_bucket::set_count(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_count(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Histogram.bucket.count) +} + +// ------------------------------------------------------------------- + +// Summary_Histogram + +// repeated .mindspore.irpb.Summary.Histogram.bucket buckets = 1; +inline int Summary_Histogram::_internal_buckets_size() const { + return buckets_.size(); +} +inline int Summary_Histogram::buckets_size() const { + return _internal_buckets_size(); +} +inline void Summary_Histogram::clear_buckets() { + buckets_.Clear(); +} +inline ::mindspore::irpb::Summary_Histogram_bucket* Summary_Histogram::mutable_buckets(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Summary.Histogram.buckets) + return buckets_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Summary_Histogram_bucket >* +Summary_Histogram::mutable_buckets() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Summary.Histogram.buckets) + return &buckets_; +} +inline const ::mindspore::irpb::Summary_Histogram_bucket& Summary_Histogram::_internal_buckets(int index) const { + return buckets_.Get(index); +} +inline const ::mindspore::irpb::Summary_Histogram_bucket& Summary_Histogram::buckets(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.buckets) + return _internal_buckets(index); +} +inline ::mindspore::irpb::Summary_Histogram_bucket* Summary_Histogram::_internal_add_buckets() { + return buckets_.Add(); +} +inline ::mindspore::irpb::Summary_Histogram_bucket* Summary_Histogram::add_buckets() { + // @@protoc_insertion_point(field_add:mindspore.irpb.Summary.Histogram.buckets) + return _internal_add_buckets(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Summary_Histogram_bucket >& +Summary_Histogram::buckets() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Summary.Histogram.buckets) + return buckets_; +} + +// optional int64 nan_count = 2; +inline bool Summary_Histogram::_internal_has_nan_count() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Summary_Histogram::has_nan_count() const { + return _internal_has_nan_count(); +} +inline void Summary_Histogram::clear_nan_count() { + nan_count_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000001u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Histogram::_internal_nan_count() const { + return nan_count_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Histogram::nan_count() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.nan_count) + return _internal_nan_count(); +} +inline void Summary_Histogram::_internal_set_nan_count(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000001u; + nan_count_ = value; +} +inline void Summary_Histogram::set_nan_count(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_nan_count(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Histogram.nan_count) +} + +// optional int64 pos_inf_count = 3; +inline bool Summary_Histogram::_internal_has_pos_inf_count() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool Summary_Histogram::has_pos_inf_count() const { + return _internal_has_pos_inf_count(); +} +inline void Summary_Histogram::clear_pos_inf_count() { + pos_inf_count_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Histogram::_internal_pos_inf_count() const { + return pos_inf_count_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Histogram::pos_inf_count() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.pos_inf_count) + return _internal_pos_inf_count(); +} +inline void Summary_Histogram::_internal_set_pos_inf_count(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000002u; + pos_inf_count_ = value; +} +inline void Summary_Histogram::set_pos_inf_count(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_pos_inf_count(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Histogram.pos_inf_count) +} + +// optional int64 neg_inf_count = 4; +inline bool Summary_Histogram::_internal_has_neg_inf_count() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool Summary_Histogram::has_neg_inf_count() const { + return _internal_has_neg_inf_count(); +} +inline void Summary_Histogram::clear_neg_inf_count() { + neg_inf_count_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000004u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Histogram::_internal_neg_inf_count() const { + return neg_inf_count_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Histogram::neg_inf_count() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.neg_inf_count) + return _internal_neg_inf_count(); +} +inline void Summary_Histogram::_internal_set_neg_inf_count(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000004u; + neg_inf_count_ = value; +} +inline void Summary_Histogram::set_neg_inf_count(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_neg_inf_count(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Histogram.neg_inf_count) +} + +// optional double max = 5; +inline bool Summary_Histogram::_internal_has_max() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool Summary_Histogram::has_max() const { + return _internal_has_max(); +} +inline void Summary_Histogram::clear_max() { + max_ = 0; + _has_bits_[0] &= ~0x00000008u; +} +inline double Summary_Histogram::_internal_max() const { + return max_; +} +inline double Summary_Histogram::max() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.max) + return _internal_max(); +} +inline void Summary_Histogram::_internal_set_max(double value) { + _has_bits_[0] |= 0x00000008u; + max_ = value; +} +inline void Summary_Histogram::set_max(double value) { + _internal_set_max(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Histogram.max) +} + +// optional double min = 6; +inline bool Summary_Histogram::_internal_has_min() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool Summary_Histogram::has_min() const { + return _internal_has_min(); +} +inline void Summary_Histogram::clear_min() { + min_ = 0; + _has_bits_[0] &= ~0x00000010u; +} +inline double Summary_Histogram::_internal_min() const { + return min_; +} +inline double Summary_Histogram::min() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.min) + return _internal_min(); +} +inline void Summary_Histogram::_internal_set_min(double value) { + _has_bits_[0] |= 0x00000010u; + min_ = value; +} +inline void Summary_Histogram::set_min(double value) { + _internal_set_min(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Histogram.min) +} + +// optional double sum = 7; +inline bool Summary_Histogram::_internal_has_sum() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + return value; +} +inline bool Summary_Histogram::has_sum() const { + return _internal_has_sum(); +} +inline void Summary_Histogram::clear_sum() { + sum_ = 0; + _has_bits_[0] &= ~0x00000020u; +} +inline double Summary_Histogram::_internal_sum() const { + return sum_; +} +inline double Summary_Histogram::sum() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.sum) + return _internal_sum(); +} +inline void Summary_Histogram::_internal_set_sum(double value) { + _has_bits_[0] |= 0x00000020u; + sum_ = value; +} +inline void Summary_Histogram::set_sum(double value) { + _internal_set_sum(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Histogram.sum) +} + +// optional int64 count = 8; +inline bool Summary_Histogram::_internal_has_count() const { + bool value = (_has_bits_[0] & 0x00000040u) != 0; + return value; +} +inline bool Summary_Histogram::has_count() const { + return _internal_has_count(); +} +inline void Summary_Histogram::clear_count() { + count_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000040u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Histogram::_internal_count() const { + return count_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Histogram::count() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Histogram.count) + return _internal_count(); +} +inline void Summary_Histogram::_internal_set_count(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000040u; + count_ = value; +} +inline void Summary_Histogram::set_count(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_count(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Histogram.count) +} + +// ------------------------------------------------------------------- + +// Summary_Value + +// required string tag = 1; +inline bool Summary_Value::_internal_has_tag() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Summary_Value::has_tag() const { + return _internal_has_tag(); +} +inline void Summary_Value::clear_tag() { + tag_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& Summary_Value::tag() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Value.tag) + return _internal_tag(); +} +inline void Summary_Value::set_tag(const std::string& value) { + _internal_set_tag(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Value.tag) +} +inline std::string* Summary_Value::mutable_tag() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Summary.Value.tag) + return _internal_mutable_tag(); +} +inline const std::string& Summary_Value::_internal_tag() const { + return tag_.Get(); +} +inline void Summary_Value::_internal_set_tag(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Summary_Value::set_tag(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + tag_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.Summary.Value.tag) +} +inline void Summary_Value::set_tag(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Summary.Value.tag) +} +inline void Summary_Value::set_tag(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Summary.Value.tag) +} +inline std::string* Summary_Value::_internal_mutable_tag() { + _has_bits_[0] |= 0x00000001u; + return tag_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Summary_Value::release_tag() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Summary.Value.tag) + if (!_internal_has_tag()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return tag_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Summary_Value::set_allocated_tag(std::string* tag) { + if (tag != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + tag_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), tag, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Summary.Value.tag) +} + +// float scalar_value = 3; +inline bool Summary_Value::_internal_has_scalar_value() const { + return value_case() == kScalarValue; +} +inline bool Summary_Value::has_scalar_value() const { + return _internal_has_scalar_value(); +} +inline void Summary_Value::set_has_scalar_value() { + _oneof_case_[0] = kScalarValue; +} +inline void Summary_Value::clear_scalar_value() { + if (_internal_has_scalar_value()) { + value_.scalar_value_ = 0; + clear_has_value(); + } +} +inline float Summary_Value::_internal_scalar_value() const { + if (_internal_has_scalar_value()) { + return value_.scalar_value_; + } + return 0; +} +inline void Summary_Value::_internal_set_scalar_value(float value) { + if (!_internal_has_scalar_value()) { + clear_value(); + set_has_scalar_value(); + } + value_.scalar_value_ = value; +} +inline float Summary_Value::scalar_value() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Value.scalar_value) + return _internal_scalar_value(); +} +inline void Summary_Value::set_scalar_value(float value) { + _internal_set_scalar_value(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Summary.Value.scalar_value) +} + +// .mindspore.irpb.Summary.Image image = 4; +inline bool Summary_Value::_internal_has_image() const { + return value_case() == kImage; +} +inline bool Summary_Value::has_image() const { + return _internal_has_image(); +} +inline void Summary_Value::set_has_image() { + _oneof_case_[0] = kImage; +} +inline void Summary_Value::clear_image() { + if (_internal_has_image()) { + if (GetArena() == nullptr) { + delete value_.image_; + } + clear_has_value(); + } +} +inline ::mindspore::irpb::Summary_Image* Summary_Value::release_image() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Summary.Value.image) + if (_internal_has_image()) { + clear_has_value(); + ::mindspore::irpb::Summary_Image* temp = value_.image_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.image_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::mindspore::irpb::Summary_Image& Summary_Value::_internal_image() const { + return _internal_has_image() + ? *value_.image_ + : *reinterpret_cast< ::mindspore::irpb::Summary_Image*>(&::mindspore::irpb::_Summary_Image_default_instance_); +} +inline const ::mindspore::irpb::Summary_Image& Summary_Value::image() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Value.image) + return _internal_image(); +} +inline ::mindspore::irpb::Summary_Image* Summary_Value::unsafe_arena_release_image() { + // @@protoc_insertion_point(field_unsafe_arena_release:mindspore.irpb.Summary.Value.image) + if (_internal_has_image()) { + clear_has_value(); + ::mindspore::irpb::Summary_Image* temp = value_.image_; + value_.image_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Summary_Value::unsafe_arena_set_allocated_image(::mindspore::irpb::Summary_Image* image) { + clear_value(); + if (image) { + set_has_image(); + value_.image_ = image; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.Summary.Value.image) +} +inline ::mindspore::irpb::Summary_Image* Summary_Value::_internal_mutable_image() { + if (!_internal_has_image()) { + clear_value(); + set_has_image(); + value_.image_ = CreateMaybeMessage< ::mindspore::irpb::Summary_Image >(GetArena()); + } + return value_.image_; +} +inline ::mindspore::irpb::Summary_Image* Summary_Value::mutable_image() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Summary.Value.image) + return _internal_mutable_image(); +} + +// .mindspore.irpb.TensorProto tensor = 8; +inline bool Summary_Value::_internal_has_tensor() const { + return value_case() == kTensor; +} +inline bool Summary_Value::has_tensor() const { + return _internal_has_tensor(); +} +inline void Summary_Value::set_has_tensor() { + _oneof_case_[0] = kTensor; +} +inline ::mindspore::irpb::TensorProto* Summary_Value::release_tensor() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Summary.Value.tensor) + if (_internal_has_tensor()) { + clear_has_value(); + ::mindspore::irpb::TensorProto* temp = value_.tensor_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.tensor_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::mindspore::irpb::TensorProto& Summary_Value::_internal_tensor() const { + return _internal_has_tensor() + ? *value_.tensor_ + : *reinterpret_cast< ::mindspore::irpb::TensorProto*>(&::mindspore::irpb::_TensorProto_default_instance_); +} +inline const ::mindspore::irpb::TensorProto& Summary_Value::tensor() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Value.tensor) + return _internal_tensor(); +} +inline ::mindspore::irpb::TensorProto* Summary_Value::unsafe_arena_release_tensor() { + // @@protoc_insertion_point(field_unsafe_arena_release:mindspore.irpb.Summary.Value.tensor) + if (_internal_has_tensor()) { + clear_has_value(); + ::mindspore::irpb::TensorProto* temp = value_.tensor_; + value_.tensor_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Summary_Value::unsafe_arena_set_allocated_tensor(::mindspore::irpb::TensorProto* tensor) { + clear_value(); + if (tensor) { + set_has_tensor(); + value_.tensor_ = tensor; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.Summary.Value.tensor) +} +inline ::mindspore::irpb::TensorProto* Summary_Value::_internal_mutable_tensor() { + if (!_internal_has_tensor()) { + clear_value(); + set_has_tensor(); + value_.tensor_ = CreateMaybeMessage< ::mindspore::irpb::TensorProto >(GetArena()); + } + return value_.tensor_; +} +inline ::mindspore::irpb::TensorProto* Summary_Value::mutable_tensor() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Summary.Value.tensor) + return _internal_mutable_tensor(); +} + +// .mindspore.irpb.Summary.Histogram histogram = 9; +inline bool Summary_Value::_internal_has_histogram() const { + return value_case() == kHistogram; +} +inline bool Summary_Value::has_histogram() const { + return _internal_has_histogram(); +} +inline void Summary_Value::set_has_histogram() { + _oneof_case_[0] = kHistogram; +} +inline void Summary_Value::clear_histogram() { + if (_internal_has_histogram()) { + if (GetArena() == nullptr) { + delete value_.histogram_; + } + clear_has_value(); + } +} +inline ::mindspore::irpb::Summary_Histogram* Summary_Value::release_histogram() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Summary.Value.histogram) + if (_internal_has_histogram()) { + clear_has_value(); + ::mindspore::irpb::Summary_Histogram* temp = value_.histogram_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.histogram_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::mindspore::irpb::Summary_Histogram& Summary_Value::_internal_histogram() const { + return _internal_has_histogram() + ? *value_.histogram_ + : *reinterpret_cast< ::mindspore::irpb::Summary_Histogram*>(&::mindspore::irpb::_Summary_Histogram_default_instance_); +} +inline const ::mindspore::irpb::Summary_Histogram& Summary_Value::histogram() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Value.histogram) + return _internal_histogram(); +} +inline ::mindspore::irpb::Summary_Histogram* Summary_Value::unsafe_arena_release_histogram() { + // @@protoc_insertion_point(field_unsafe_arena_release:mindspore.irpb.Summary.Value.histogram) + if (_internal_has_histogram()) { + clear_has_value(); + ::mindspore::irpb::Summary_Histogram* temp = value_.histogram_; + value_.histogram_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Summary_Value::unsafe_arena_set_allocated_histogram(::mindspore::irpb::Summary_Histogram* histogram) { + clear_value(); + if (histogram) { + set_has_histogram(); + value_.histogram_ = histogram; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.Summary.Value.histogram) +} +inline ::mindspore::irpb::Summary_Histogram* Summary_Value::_internal_mutable_histogram() { + if (!_internal_has_histogram()) { + clear_value(); + set_has_histogram(); + value_.histogram_ = CreateMaybeMessage< ::mindspore::irpb::Summary_Histogram >(GetArena()); + } + return value_.histogram_; +} +inline ::mindspore::irpb::Summary_Histogram* Summary_Value::mutable_histogram() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Summary.Value.histogram) + return _internal_mutable_histogram(); +} + +// .mindspore.irpb.LossLandscape loss_landscape = 10; +inline bool Summary_Value::_internal_has_loss_landscape() const { + return value_case() == kLossLandscape; +} +inline bool Summary_Value::has_loss_landscape() const { + return _internal_has_loss_landscape(); +} +inline void Summary_Value::set_has_loss_landscape() { + _oneof_case_[0] = kLossLandscape; +} +inline void Summary_Value::clear_loss_landscape() { + if (_internal_has_loss_landscape()) { + if (GetArena() == nullptr) { + delete value_.loss_landscape_; + } + clear_has_value(); + } +} +inline ::mindspore::irpb::LossLandscape* Summary_Value::release_loss_landscape() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Summary.Value.loss_landscape) + if (_internal_has_loss_landscape()) { + clear_has_value(); + ::mindspore::irpb::LossLandscape* temp = value_.loss_landscape_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.loss_landscape_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::mindspore::irpb::LossLandscape& Summary_Value::_internal_loss_landscape() const { + return _internal_has_loss_landscape() + ? *value_.loss_landscape_ + : *reinterpret_cast< ::mindspore::irpb::LossLandscape*>(&::mindspore::irpb::_LossLandscape_default_instance_); +} +inline const ::mindspore::irpb::LossLandscape& Summary_Value::loss_landscape() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.Value.loss_landscape) + return _internal_loss_landscape(); +} +inline ::mindspore::irpb::LossLandscape* Summary_Value::unsafe_arena_release_loss_landscape() { + // @@protoc_insertion_point(field_unsafe_arena_release:mindspore.irpb.Summary.Value.loss_landscape) + if (_internal_has_loss_landscape()) { + clear_has_value(); + ::mindspore::irpb::LossLandscape* temp = value_.loss_landscape_; + value_.loss_landscape_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Summary_Value::unsafe_arena_set_allocated_loss_landscape(::mindspore::irpb::LossLandscape* loss_landscape) { + clear_value(); + if (loss_landscape) { + set_has_loss_landscape(); + value_.loss_landscape_ = loss_landscape; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.Summary.Value.loss_landscape) +} +inline ::mindspore::irpb::LossLandscape* Summary_Value::_internal_mutable_loss_landscape() { + if (!_internal_has_loss_landscape()) { + clear_value(); + set_has_loss_landscape(); + value_.loss_landscape_ = CreateMaybeMessage< ::mindspore::irpb::LossLandscape >(GetArena()); + } + return value_.loss_landscape_; +} +inline ::mindspore::irpb::LossLandscape* Summary_Value::mutable_loss_landscape() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Summary.Value.loss_landscape) + return _internal_mutable_loss_landscape(); +} + +inline bool Summary_Value::has_value() const { + return value_case() != VALUE_NOT_SET; +} +inline void Summary_Value::clear_has_value() { + _oneof_case_[0] = VALUE_NOT_SET; +} +inline Summary_Value::ValueCase Summary_Value::value_case() const { + return Summary_Value::ValueCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + +// Summary + +// repeated .mindspore.irpb.Summary.Value value = 1; +inline int Summary::_internal_value_size() const { + return value_.size(); +} +inline int Summary::value_size() const { + return _internal_value_size(); +} +inline void Summary::clear_value() { + value_.Clear(); +} +inline ::mindspore::irpb::Summary_Value* Summary::mutable_value(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Summary.value) + return value_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Summary_Value >* +Summary::mutable_value() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Summary.value) + return &value_; +} +inline const ::mindspore::irpb::Summary_Value& Summary::_internal_value(int index) const { + return value_.Get(index); +} +inline const ::mindspore::irpb::Summary_Value& Summary::value(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Summary.value) + return _internal_value(index); +} +inline ::mindspore::irpb::Summary_Value* Summary::_internal_add_value() { + return value_.Add(); +} +inline ::mindspore::irpb::Summary_Value* Summary::add_value() { + // @@protoc_insertion_point(field_add:mindspore.irpb.Summary.value) + return _internal_add_value(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Summary_Value >& +Summary::value() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Summary.value) + return value_; +} + +// ------------------------------------------------------------------- + +// Explain_Inference + +// repeated float ground_truth_prob = 1; +inline int Explain_Inference::_internal_ground_truth_prob_size() const { + return ground_truth_prob_.size(); +} +inline int Explain_Inference::ground_truth_prob_size() const { + return _internal_ground_truth_prob_size(); +} +inline void Explain_Inference::clear_ground_truth_prob() { + ground_truth_prob_.Clear(); +} +inline float Explain_Inference::_internal_ground_truth_prob(int index) const { + return ground_truth_prob_.Get(index); +} +inline float Explain_Inference::ground_truth_prob(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Inference.ground_truth_prob) + return _internal_ground_truth_prob(index); +} +inline void Explain_Inference::set_ground_truth_prob(int index, float value) { + ground_truth_prob_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Inference.ground_truth_prob) +} +inline void Explain_Inference::_internal_add_ground_truth_prob(float value) { + ground_truth_prob_.Add(value); +} +inline void Explain_Inference::add_ground_truth_prob(float value) { + _internal_add_ground_truth_prob(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Inference.ground_truth_prob) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::_internal_ground_truth_prob() const { + return ground_truth_prob_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::ground_truth_prob() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Inference.ground_truth_prob) + return _internal_ground_truth_prob(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::_internal_mutable_ground_truth_prob() { + return &ground_truth_prob_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::mutable_ground_truth_prob() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Inference.ground_truth_prob) + return _internal_mutable_ground_truth_prob(); +} + +// repeated int32 predicted_label = 2; +inline int Explain_Inference::_internal_predicted_label_size() const { + return predicted_label_.size(); +} +inline int Explain_Inference::predicted_label_size() const { + return _internal_predicted_label_size(); +} +inline void Explain_Inference::clear_predicted_label() { + predicted_label_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain_Inference::_internal_predicted_label(int index) const { + return predicted_label_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain_Inference::predicted_label(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Inference.predicted_label) + return _internal_predicted_label(index); +} +inline void Explain_Inference::set_predicted_label(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + predicted_label_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Inference.predicted_label) +} +inline void Explain_Inference::_internal_add_predicted_label(::PROTOBUF_NAMESPACE_ID::int32 value) { + predicted_label_.Add(value); +} +inline void Explain_Inference::add_predicted_label(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_predicted_label(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Inference.predicted_label) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +Explain_Inference::_internal_predicted_label() const { + return predicted_label_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +Explain_Inference::predicted_label() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Inference.predicted_label) + return _internal_predicted_label(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +Explain_Inference::_internal_mutable_predicted_label() { + return &predicted_label_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +Explain_Inference::mutable_predicted_label() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Inference.predicted_label) + return _internal_mutable_predicted_label(); +} + +// repeated float predicted_prob = 3; +inline int Explain_Inference::_internal_predicted_prob_size() const { + return predicted_prob_.size(); +} +inline int Explain_Inference::predicted_prob_size() const { + return _internal_predicted_prob_size(); +} +inline void Explain_Inference::clear_predicted_prob() { + predicted_prob_.Clear(); +} +inline float Explain_Inference::_internal_predicted_prob(int index) const { + return predicted_prob_.Get(index); +} +inline float Explain_Inference::predicted_prob(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Inference.predicted_prob) + return _internal_predicted_prob(index); +} +inline void Explain_Inference::set_predicted_prob(int index, float value) { + predicted_prob_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Inference.predicted_prob) +} +inline void Explain_Inference::_internal_add_predicted_prob(float value) { + predicted_prob_.Add(value); +} +inline void Explain_Inference::add_predicted_prob(float value) { + _internal_add_predicted_prob(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Inference.predicted_prob) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::_internal_predicted_prob() const { + return predicted_prob_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::predicted_prob() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Inference.predicted_prob) + return _internal_predicted_prob(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::_internal_mutable_predicted_prob() { + return &predicted_prob_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::mutable_predicted_prob() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Inference.predicted_prob) + return _internal_mutable_predicted_prob(); +} + +// repeated float ground_truth_prob_sd = 4; +inline int Explain_Inference::_internal_ground_truth_prob_sd_size() const { + return ground_truth_prob_sd_.size(); +} +inline int Explain_Inference::ground_truth_prob_sd_size() const { + return _internal_ground_truth_prob_sd_size(); +} +inline void Explain_Inference::clear_ground_truth_prob_sd() { + ground_truth_prob_sd_.Clear(); +} +inline float Explain_Inference::_internal_ground_truth_prob_sd(int index) const { + return ground_truth_prob_sd_.Get(index); +} +inline float Explain_Inference::ground_truth_prob_sd(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Inference.ground_truth_prob_sd) + return _internal_ground_truth_prob_sd(index); +} +inline void Explain_Inference::set_ground_truth_prob_sd(int index, float value) { + ground_truth_prob_sd_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Inference.ground_truth_prob_sd) +} +inline void Explain_Inference::_internal_add_ground_truth_prob_sd(float value) { + ground_truth_prob_sd_.Add(value); +} +inline void Explain_Inference::add_ground_truth_prob_sd(float value) { + _internal_add_ground_truth_prob_sd(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Inference.ground_truth_prob_sd) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::_internal_ground_truth_prob_sd() const { + return ground_truth_prob_sd_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::ground_truth_prob_sd() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Inference.ground_truth_prob_sd) + return _internal_ground_truth_prob_sd(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::_internal_mutable_ground_truth_prob_sd() { + return &ground_truth_prob_sd_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::mutable_ground_truth_prob_sd() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Inference.ground_truth_prob_sd) + return _internal_mutable_ground_truth_prob_sd(); +} + +// repeated float ground_truth_prob_itl95_low = 5; +inline int Explain_Inference::_internal_ground_truth_prob_itl95_low_size() const { + return ground_truth_prob_itl95_low_.size(); +} +inline int Explain_Inference::ground_truth_prob_itl95_low_size() const { + return _internal_ground_truth_prob_itl95_low_size(); +} +inline void Explain_Inference::clear_ground_truth_prob_itl95_low() { + ground_truth_prob_itl95_low_.Clear(); +} +inline float Explain_Inference::_internal_ground_truth_prob_itl95_low(int index) const { + return ground_truth_prob_itl95_low_.Get(index); +} +inline float Explain_Inference::ground_truth_prob_itl95_low(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Inference.ground_truth_prob_itl95_low) + return _internal_ground_truth_prob_itl95_low(index); +} +inline void Explain_Inference::set_ground_truth_prob_itl95_low(int index, float value) { + ground_truth_prob_itl95_low_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Inference.ground_truth_prob_itl95_low) +} +inline void Explain_Inference::_internal_add_ground_truth_prob_itl95_low(float value) { + ground_truth_prob_itl95_low_.Add(value); +} +inline void Explain_Inference::add_ground_truth_prob_itl95_low(float value) { + _internal_add_ground_truth_prob_itl95_low(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Inference.ground_truth_prob_itl95_low) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::_internal_ground_truth_prob_itl95_low() const { + return ground_truth_prob_itl95_low_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::ground_truth_prob_itl95_low() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Inference.ground_truth_prob_itl95_low) + return _internal_ground_truth_prob_itl95_low(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::_internal_mutable_ground_truth_prob_itl95_low() { + return &ground_truth_prob_itl95_low_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::mutable_ground_truth_prob_itl95_low() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Inference.ground_truth_prob_itl95_low) + return _internal_mutable_ground_truth_prob_itl95_low(); +} + +// repeated float ground_truth_prob_itl95_hi = 6; +inline int Explain_Inference::_internal_ground_truth_prob_itl95_hi_size() const { + return ground_truth_prob_itl95_hi_.size(); +} +inline int Explain_Inference::ground_truth_prob_itl95_hi_size() const { + return _internal_ground_truth_prob_itl95_hi_size(); +} +inline void Explain_Inference::clear_ground_truth_prob_itl95_hi() { + ground_truth_prob_itl95_hi_.Clear(); +} +inline float Explain_Inference::_internal_ground_truth_prob_itl95_hi(int index) const { + return ground_truth_prob_itl95_hi_.Get(index); +} +inline float Explain_Inference::ground_truth_prob_itl95_hi(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Inference.ground_truth_prob_itl95_hi) + return _internal_ground_truth_prob_itl95_hi(index); +} +inline void Explain_Inference::set_ground_truth_prob_itl95_hi(int index, float value) { + ground_truth_prob_itl95_hi_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Inference.ground_truth_prob_itl95_hi) +} +inline void Explain_Inference::_internal_add_ground_truth_prob_itl95_hi(float value) { + ground_truth_prob_itl95_hi_.Add(value); +} +inline void Explain_Inference::add_ground_truth_prob_itl95_hi(float value) { + _internal_add_ground_truth_prob_itl95_hi(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Inference.ground_truth_prob_itl95_hi) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::_internal_ground_truth_prob_itl95_hi() const { + return ground_truth_prob_itl95_hi_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::ground_truth_prob_itl95_hi() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Inference.ground_truth_prob_itl95_hi) + return _internal_ground_truth_prob_itl95_hi(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::_internal_mutable_ground_truth_prob_itl95_hi() { + return &ground_truth_prob_itl95_hi_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::mutable_ground_truth_prob_itl95_hi() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Inference.ground_truth_prob_itl95_hi) + return _internal_mutable_ground_truth_prob_itl95_hi(); +} + +// repeated float predicted_prob_sd = 7; +inline int Explain_Inference::_internal_predicted_prob_sd_size() const { + return predicted_prob_sd_.size(); +} +inline int Explain_Inference::predicted_prob_sd_size() const { + return _internal_predicted_prob_sd_size(); +} +inline void Explain_Inference::clear_predicted_prob_sd() { + predicted_prob_sd_.Clear(); +} +inline float Explain_Inference::_internal_predicted_prob_sd(int index) const { + return predicted_prob_sd_.Get(index); +} +inline float Explain_Inference::predicted_prob_sd(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Inference.predicted_prob_sd) + return _internal_predicted_prob_sd(index); +} +inline void Explain_Inference::set_predicted_prob_sd(int index, float value) { + predicted_prob_sd_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Inference.predicted_prob_sd) +} +inline void Explain_Inference::_internal_add_predicted_prob_sd(float value) { + predicted_prob_sd_.Add(value); +} +inline void Explain_Inference::add_predicted_prob_sd(float value) { + _internal_add_predicted_prob_sd(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Inference.predicted_prob_sd) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::_internal_predicted_prob_sd() const { + return predicted_prob_sd_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::predicted_prob_sd() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Inference.predicted_prob_sd) + return _internal_predicted_prob_sd(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::_internal_mutable_predicted_prob_sd() { + return &predicted_prob_sd_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::mutable_predicted_prob_sd() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Inference.predicted_prob_sd) + return _internal_mutable_predicted_prob_sd(); +} + +// repeated float predicted_prob_itl95_low = 8; +inline int Explain_Inference::_internal_predicted_prob_itl95_low_size() const { + return predicted_prob_itl95_low_.size(); +} +inline int Explain_Inference::predicted_prob_itl95_low_size() const { + return _internal_predicted_prob_itl95_low_size(); +} +inline void Explain_Inference::clear_predicted_prob_itl95_low() { + predicted_prob_itl95_low_.Clear(); +} +inline float Explain_Inference::_internal_predicted_prob_itl95_low(int index) const { + return predicted_prob_itl95_low_.Get(index); +} +inline float Explain_Inference::predicted_prob_itl95_low(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Inference.predicted_prob_itl95_low) + return _internal_predicted_prob_itl95_low(index); +} +inline void Explain_Inference::set_predicted_prob_itl95_low(int index, float value) { + predicted_prob_itl95_low_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Inference.predicted_prob_itl95_low) +} +inline void Explain_Inference::_internal_add_predicted_prob_itl95_low(float value) { + predicted_prob_itl95_low_.Add(value); +} +inline void Explain_Inference::add_predicted_prob_itl95_low(float value) { + _internal_add_predicted_prob_itl95_low(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Inference.predicted_prob_itl95_low) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::_internal_predicted_prob_itl95_low() const { + return predicted_prob_itl95_low_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::predicted_prob_itl95_low() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Inference.predicted_prob_itl95_low) + return _internal_predicted_prob_itl95_low(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::_internal_mutable_predicted_prob_itl95_low() { + return &predicted_prob_itl95_low_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::mutable_predicted_prob_itl95_low() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Inference.predicted_prob_itl95_low) + return _internal_mutable_predicted_prob_itl95_low(); +} + +// repeated float predicted_prob_itl95_hi = 9; +inline int Explain_Inference::_internal_predicted_prob_itl95_hi_size() const { + return predicted_prob_itl95_hi_.size(); +} +inline int Explain_Inference::predicted_prob_itl95_hi_size() const { + return _internal_predicted_prob_itl95_hi_size(); +} +inline void Explain_Inference::clear_predicted_prob_itl95_hi() { + predicted_prob_itl95_hi_.Clear(); +} +inline float Explain_Inference::_internal_predicted_prob_itl95_hi(int index) const { + return predicted_prob_itl95_hi_.Get(index); +} +inline float Explain_Inference::predicted_prob_itl95_hi(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Inference.predicted_prob_itl95_hi) + return _internal_predicted_prob_itl95_hi(index); +} +inline void Explain_Inference::set_predicted_prob_itl95_hi(int index, float value) { + predicted_prob_itl95_hi_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Inference.predicted_prob_itl95_hi) +} +inline void Explain_Inference::_internal_add_predicted_prob_itl95_hi(float value) { + predicted_prob_itl95_hi_.Add(value); +} +inline void Explain_Inference::add_predicted_prob_itl95_hi(float value) { + _internal_add_predicted_prob_itl95_hi(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Inference.predicted_prob_itl95_hi) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::_internal_predicted_prob_itl95_hi() const { + return predicted_prob_itl95_hi_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Inference::predicted_prob_itl95_hi() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Inference.predicted_prob_itl95_hi) + return _internal_predicted_prob_itl95_hi(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::_internal_mutable_predicted_prob_itl95_hi() { + return &predicted_prob_itl95_hi_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Inference::mutable_predicted_prob_itl95_hi() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Inference.predicted_prob_itl95_hi) + return _internal_mutable_predicted_prob_itl95_hi(); +} + +// ------------------------------------------------------------------- + +// Explain_Explanation + +// optional string explain_method = 1; +inline bool Explain_Explanation::_internal_has_explain_method() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Explain_Explanation::has_explain_method() const { + return _internal_has_explain_method(); +} +inline void Explain_Explanation::clear_explain_method() { + explain_method_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& Explain_Explanation::explain_method() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Explanation.explain_method) + return _internal_explain_method(); +} +inline void Explain_Explanation::set_explain_method(const std::string& value) { + _internal_set_explain_method(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Explanation.explain_method) +} +inline std::string* Explain_Explanation::mutable_explain_method() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.Explanation.explain_method) + return _internal_mutable_explain_method(); +} +inline const std::string& Explain_Explanation::_internal_explain_method() const { + return explain_method_.Get(); +} +inline void Explain_Explanation::_internal_set_explain_method(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + explain_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Explain_Explanation::set_explain_method(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + explain_method_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.Explain.Explanation.explain_method) +} +inline void Explain_Explanation::set_explain_method(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + explain_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Explain.Explanation.explain_method) +} +inline void Explain_Explanation::set_explain_method(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + explain_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Explain.Explanation.explain_method) +} +inline std::string* Explain_Explanation::_internal_mutable_explain_method() { + _has_bits_[0] |= 0x00000001u; + return explain_method_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Explain_Explanation::release_explain_method() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Explain.Explanation.explain_method) + if (!_internal_has_explain_method()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return explain_method_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Explain_Explanation::set_allocated_explain_method(std::string* explain_method) { + if (explain_method != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + explain_method_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), explain_method, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Explain.Explanation.explain_method) +} + +// optional int32 label = 2; +inline bool Explain_Explanation::_internal_has_label() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool Explain_Explanation::has_label() const { + return _internal_has_label(); +} +inline void Explain_Explanation::clear_label() { + label_ = 0; + _has_bits_[0] &= ~0x00000004u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain_Explanation::_internal_label() const { + return label_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain_Explanation::label() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Explanation.label) + return _internal_label(); +} +inline void Explain_Explanation::_internal_set_label(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000004u; + label_ = value; +} +inline void Explain_Explanation::set_label(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_label(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Explanation.label) +} + +// optional string heatmap_path = 3; +inline bool Explain_Explanation::_internal_has_heatmap_path() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool Explain_Explanation::has_heatmap_path() const { + return _internal_has_heatmap_path(); +} +inline void Explain_Explanation::clear_heatmap_path() { + heatmap_path_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& Explain_Explanation::heatmap_path() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Explanation.heatmap_path) + return _internal_heatmap_path(); +} +inline void Explain_Explanation::set_heatmap_path(const std::string& value) { + _internal_set_heatmap_path(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Explanation.heatmap_path) +} +inline std::string* Explain_Explanation::mutable_heatmap_path() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.Explanation.heatmap_path) + return _internal_mutable_heatmap_path(); +} +inline const std::string& Explain_Explanation::_internal_heatmap_path() const { + return heatmap_path_.Get(); +} +inline void Explain_Explanation::_internal_set_heatmap_path(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + heatmap_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Explain_Explanation::set_heatmap_path(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + heatmap_path_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.Explain.Explanation.heatmap_path) +} +inline void Explain_Explanation::set_heatmap_path(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + heatmap_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Explain.Explanation.heatmap_path) +} +inline void Explain_Explanation::set_heatmap_path(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + heatmap_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Explain.Explanation.heatmap_path) +} +inline std::string* Explain_Explanation::_internal_mutable_heatmap_path() { + _has_bits_[0] |= 0x00000002u; + return heatmap_path_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Explain_Explanation::release_heatmap_path() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Explain.Explanation.heatmap_path) + if (!_internal_has_heatmap_path()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return heatmap_path_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Explain_Explanation::set_allocated_heatmap_path(std::string* heatmap_path) { + if (heatmap_path != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + heatmap_path_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), heatmap_path, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Explain.Explanation.heatmap_path) +} + +// ------------------------------------------------------------------- + +// Explain_Benchmark + +// optional string benchmark_method = 1; +inline bool Explain_Benchmark::_internal_has_benchmark_method() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Explain_Benchmark::has_benchmark_method() const { + return _internal_has_benchmark_method(); +} +inline void Explain_Benchmark::clear_benchmark_method() { + benchmark_method_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& Explain_Benchmark::benchmark_method() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Benchmark.benchmark_method) + return _internal_benchmark_method(); +} +inline void Explain_Benchmark::set_benchmark_method(const std::string& value) { + _internal_set_benchmark_method(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Benchmark.benchmark_method) +} +inline std::string* Explain_Benchmark::mutable_benchmark_method() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.Benchmark.benchmark_method) + return _internal_mutable_benchmark_method(); +} +inline const std::string& Explain_Benchmark::_internal_benchmark_method() const { + return benchmark_method_.Get(); +} +inline void Explain_Benchmark::_internal_set_benchmark_method(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + benchmark_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Explain_Benchmark::set_benchmark_method(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + benchmark_method_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.Explain.Benchmark.benchmark_method) +} +inline void Explain_Benchmark::set_benchmark_method(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + benchmark_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Explain.Benchmark.benchmark_method) +} +inline void Explain_Benchmark::set_benchmark_method(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + benchmark_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Explain.Benchmark.benchmark_method) +} +inline std::string* Explain_Benchmark::_internal_mutable_benchmark_method() { + _has_bits_[0] |= 0x00000001u; + return benchmark_method_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Explain_Benchmark::release_benchmark_method() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Explain.Benchmark.benchmark_method) + if (!_internal_has_benchmark_method()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return benchmark_method_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Explain_Benchmark::set_allocated_benchmark_method(std::string* benchmark_method) { + if (benchmark_method != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + benchmark_method_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), benchmark_method, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Explain.Benchmark.benchmark_method) +} + +// optional string explain_method = 2; +inline bool Explain_Benchmark::_internal_has_explain_method() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool Explain_Benchmark::has_explain_method() const { + return _internal_has_explain_method(); +} +inline void Explain_Benchmark::clear_explain_method() { + explain_method_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& Explain_Benchmark::explain_method() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Benchmark.explain_method) + return _internal_explain_method(); +} +inline void Explain_Benchmark::set_explain_method(const std::string& value) { + _internal_set_explain_method(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Benchmark.explain_method) +} +inline std::string* Explain_Benchmark::mutable_explain_method() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.Benchmark.explain_method) + return _internal_mutable_explain_method(); +} +inline const std::string& Explain_Benchmark::_internal_explain_method() const { + return explain_method_.Get(); +} +inline void Explain_Benchmark::_internal_set_explain_method(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + explain_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Explain_Benchmark::set_explain_method(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + explain_method_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.Explain.Benchmark.explain_method) +} +inline void Explain_Benchmark::set_explain_method(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + explain_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Explain.Benchmark.explain_method) +} +inline void Explain_Benchmark::set_explain_method(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + explain_method_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Explain.Benchmark.explain_method) +} +inline std::string* Explain_Benchmark::_internal_mutable_explain_method() { + _has_bits_[0] |= 0x00000002u; + return explain_method_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Explain_Benchmark::release_explain_method() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Explain.Benchmark.explain_method) + if (!_internal_has_explain_method()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return explain_method_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Explain_Benchmark::set_allocated_explain_method(std::string* explain_method) { + if (explain_method != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + explain_method_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), explain_method, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Explain.Benchmark.explain_method) +} + +// optional float total_score = 3; +inline bool Explain_Benchmark::_internal_has_total_score() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool Explain_Benchmark::has_total_score() const { + return _internal_has_total_score(); +} +inline void Explain_Benchmark::clear_total_score() { + total_score_ = 0; + _has_bits_[0] &= ~0x00000004u; +} +inline float Explain_Benchmark::_internal_total_score() const { + return total_score_; +} +inline float Explain_Benchmark::total_score() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Benchmark.total_score) + return _internal_total_score(); +} +inline void Explain_Benchmark::_internal_set_total_score(float value) { + _has_bits_[0] |= 0x00000004u; + total_score_ = value; +} +inline void Explain_Benchmark::set_total_score(float value) { + _internal_set_total_score(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Benchmark.total_score) +} + +// repeated float label_score = 4; +inline int Explain_Benchmark::_internal_label_score_size() const { + return label_score_.size(); +} +inline int Explain_Benchmark::label_score_size() const { + return _internal_label_score_size(); +} +inline void Explain_Benchmark::clear_label_score() { + label_score_.Clear(); +} +inline float Explain_Benchmark::_internal_label_score(int index) const { + return label_score_.Get(index); +} +inline float Explain_Benchmark::label_score(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Benchmark.label_score) + return _internal_label_score(index); +} +inline void Explain_Benchmark::set_label_score(int index, float value) { + label_score_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Benchmark.label_score) +} +inline void Explain_Benchmark::_internal_add_label_score(float value) { + label_score_.Add(value); +} +inline void Explain_Benchmark::add_label_score(float value) { + _internal_add_label_score(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Benchmark.label_score) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Benchmark::_internal_label_score() const { + return label_score_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +Explain_Benchmark::label_score() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Benchmark.label_score) + return _internal_label_score(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Benchmark::_internal_mutable_label_score() { + return &label_score_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +Explain_Benchmark::mutable_label_score() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Benchmark.label_score) + return _internal_mutable_label_score(); +} + +// ------------------------------------------------------------------- + +// Explain_Metadata + +// repeated string label = 1; +inline int Explain_Metadata::_internal_label_size() const { + return label_.size(); +} +inline int Explain_Metadata::label_size() const { + return _internal_label_size(); +} +inline void Explain_Metadata::clear_label() { + label_.Clear(); +} +inline std::string* Explain_Metadata::add_label() { + // @@protoc_insertion_point(field_add_mutable:mindspore.irpb.Explain.Metadata.label) + return _internal_add_label(); +} +inline const std::string& Explain_Metadata::_internal_label(int index) const { + return label_.Get(index); +} +inline const std::string& Explain_Metadata::label(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Metadata.label) + return _internal_label(index); +} +inline std::string* Explain_Metadata::mutable_label(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.Metadata.label) + return label_.Mutable(index); +} +inline void Explain_Metadata::set_label(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Metadata.label) + label_.Mutable(index)->assign(value); +} +inline void Explain_Metadata::set_label(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Metadata.label) + label_.Mutable(index)->assign(std::move(value)); +} +inline void Explain_Metadata::set_label(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + label_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Explain.Metadata.label) +} +inline void Explain_Metadata::set_label(int index, const char* value, size_t size) { + label_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Explain.Metadata.label) +} +inline std::string* Explain_Metadata::_internal_add_label() { + return label_.Add(); +} +inline void Explain_Metadata::add_label(const std::string& value) { + label_.Add()->assign(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Metadata.label) +} +inline void Explain_Metadata::add_label(std::string&& value) { + label_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Metadata.label) +} +inline void Explain_Metadata::add_label(const char* value) { + GOOGLE_DCHECK(value != nullptr); + label_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:mindspore.irpb.Explain.Metadata.label) +} +inline void Explain_Metadata::add_label(const char* value, size_t size) { + label_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:mindspore.irpb.Explain.Metadata.label) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +Explain_Metadata::label() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Metadata.label) + return label_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +Explain_Metadata::mutable_label() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Metadata.label) + return &label_; +} + +// repeated string explain_method = 2; +inline int Explain_Metadata::_internal_explain_method_size() const { + return explain_method_.size(); +} +inline int Explain_Metadata::explain_method_size() const { + return _internal_explain_method_size(); +} +inline void Explain_Metadata::clear_explain_method() { + explain_method_.Clear(); +} +inline std::string* Explain_Metadata::add_explain_method() { + // @@protoc_insertion_point(field_add_mutable:mindspore.irpb.Explain.Metadata.explain_method) + return _internal_add_explain_method(); +} +inline const std::string& Explain_Metadata::_internal_explain_method(int index) const { + return explain_method_.Get(index); +} +inline const std::string& Explain_Metadata::explain_method(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Metadata.explain_method) + return _internal_explain_method(index); +} +inline std::string* Explain_Metadata::mutable_explain_method(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.Metadata.explain_method) + return explain_method_.Mutable(index); +} +inline void Explain_Metadata::set_explain_method(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Metadata.explain_method) + explain_method_.Mutable(index)->assign(value); +} +inline void Explain_Metadata::set_explain_method(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Metadata.explain_method) + explain_method_.Mutable(index)->assign(std::move(value)); +} +inline void Explain_Metadata::set_explain_method(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + explain_method_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Explain.Metadata.explain_method) +} +inline void Explain_Metadata::set_explain_method(int index, const char* value, size_t size) { + explain_method_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Explain.Metadata.explain_method) +} +inline std::string* Explain_Metadata::_internal_add_explain_method() { + return explain_method_.Add(); +} +inline void Explain_Metadata::add_explain_method(const std::string& value) { + explain_method_.Add()->assign(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Metadata.explain_method) +} +inline void Explain_Metadata::add_explain_method(std::string&& value) { + explain_method_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Metadata.explain_method) +} +inline void Explain_Metadata::add_explain_method(const char* value) { + GOOGLE_DCHECK(value != nullptr); + explain_method_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:mindspore.irpb.Explain.Metadata.explain_method) +} +inline void Explain_Metadata::add_explain_method(const char* value, size_t size) { + explain_method_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:mindspore.irpb.Explain.Metadata.explain_method) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +Explain_Metadata::explain_method() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Metadata.explain_method) + return explain_method_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +Explain_Metadata::mutable_explain_method() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Metadata.explain_method) + return &explain_method_; +} + +// repeated string benchmark_method = 3; +inline int Explain_Metadata::_internal_benchmark_method_size() const { + return benchmark_method_.size(); +} +inline int Explain_Metadata::benchmark_method_size() const { + return _internal_benchmark_method_size(); +} +inline void Explain_Metadata::clear_benchmark_method() { + benchmark_method_.Clear(); +} +inline std::string* Explain_Metadata::add_benchmark_method() { + // @@protoc_insertion_point(field_add_mutable:mindspore.irpb.Explain.Metadata.benchmark_method) + return _internal_add_benchmark_method(); +} +inline const std::string& Explain_Metadata::_internal_benchmark_method(int index) const { + return benchmark_method_.Get(index); +} +inline const std::string& Explain_Metadata::benchmark_method(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Metadata.benchmark_method) + return _internal_benchmark_method(index); +} +inline std::string* Explain_Metadata::mutable_benchmark_method(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.Metadata.benchmark_method) + return benchmark_method_.Mutable(index); +} +inline void Explain_Metadata::set_benchmark_method(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Metadata.benchmark_method) + benchmark_method_.Mutable(index)->assign(value); +} +inline void Explain_Metadata::set_benchmark_method(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Metadata.benchmark_method) + benchmark_method_.Mutable(index)->assign(std::move(value)); +} +inline void Explain_Metadata::set_benchmark_method(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + benchmark_method_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Explain.Metadata.benchmark_method) +} +inline void Explain_Metadata::set_benchmark_method(int index, const char* value, size_t size) { + benchmark_method_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Explain.Metadata.benchmark_method) +} +inline std::string* Explain_Metadata::_internal_add_benchmark_method() { + return benchmark_method_.Add(); +} +inline void Explain_Metadata::add_benchmark_method(const std::string& value) { + benchmark_method_.Add()->assign(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Metadata.benchmark_method) +} +inline void Explain_Metadata::add_benchmark_method(std::string&& value) { + benchmark_method_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Metadata.benchmark_method) +} +inline void Explain_Metadata::add_benchmark_method(const char* value) { + GOOGLE_DCHECK(value != nullptr); + benchmark_method_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:mindspore.irpb.Explain.Metadata.benchmark_method) +} +inline void Explain_Metadata::add_benchmark_method(const char* value, size_t size) { + benchmark_method_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:mindspore.irpb.Explain.Metadata.benchmark_method) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +Explain_Metadata::benchmark_method() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Metadata.benchmark_method) + return benchmark_method_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +Explain_Metadata::mutable_benchmark_method() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Metadata.benchmark_method) + return &benchmark_method_; +} + +// ------------------------------------------------------------------- + +// Explain_HocLayer + +// optional float prob = 1; +inline bool Explain_HocLayer::_internal_has_prob() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Explain_HocLayer::has_prob() const { + return _internal_has_prob(); +} +inline void Explain_HocLayer::clear_prob() { + prob_ = 0; + _has_bits_[0] &= ~0x00000001u; +} +inline float Explain_HocLayer::_internal_prob() const { + return prob_; +} +inline float Explain_HocLayer::prob() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.HocLayer.prob) + return _internal_prob(); +} +inline void Explain_HocLayer::_internal_set_prob(float value) { + _has_bits_[0] |= 0x00000001u; + prob_ = value; +} +inline void Explain_HocLayer::set_prob(float value) { + _internal_set_prob(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.HocLayer.prob) +} + +// repeated int32 box = 2; +inline int Explain_HocLayer::_internal_box_size() const { + return box_.size(); +} +inline int Explain_HocLayer::box_size() const { + return _internal_box_size(); +} +inline void Explain_HocLayer::clear_box() { + box_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain_HocLayer::_internal_box(int index) const { + return box_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain_HocLayer::box(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.HocLayer.box) + return _internal_box(index); +} +inline void Explain_HocLayer::set_box(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + box_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.HocLayer.box) +} +inline void Explain_HocLayer::_internal_add_box(::PROTOBUF_NAMESPACE_ID::int32 value) { + box_.Add(value); +} +inline void Explain_HocLayer::add_box(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_box(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.HocLayer.box) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +Explain_HocLayer::_internal_box() const { + return box_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +Explain_HocLayer::box() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.HocLayer.box) + return _internal_box(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +Explain_HocLayer::_internal_mutable_box() { + return &box_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +Explain_HocLayer::mutable_box() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.HocLayer.box) + return _internal_mutable_box(); +} + +// ------------------------------------------------------------------- + +// Explain_Hoc + +// optional int32 label = 1; +inline bool Explain_Hoc::_internal_has_label() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool Explain_Hoc::has_label() const { + return _internal_has_label(); +} +inline void Explain_Hoc::clear_label() { + label_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain_Hoc::_internal_label() const { + return label_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain_Hoc::label() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Hoc.label) + return _internal_label(); +} +inline void Explain_Hoc::_internal_set_label(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000002u; + label_ = value; +} +inline void Explain_Hoc::set_label(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_label(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Hoc.label) +} + +// optional string mask = 2; +inline bool Explain_Hoc::_internal_has_mask() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Explain_Hoc::has_mask() const { + return _internal_has_mask(); +} +inline void Explain_Hoc::clear_mask() { + mask_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& Explain_Hoc::mask() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Hoc.mask) + return _internal_mask(); +} +inline void Explain_Hoc::set_mask(const std::string& value) { + _internal_set_mask(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.Hoc.mask) +} +inline std::string* Explain_Hoc::mutable_mask() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.Hoc.mask) + return _internal_mutable_mask(); +} +inline const std::string& Explain_Hoc::_internal_mask() const { + return mask_.Get(); +} +inline void Explain_Hoc::_internal_set_mask(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + mask_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Explain_Hoc::set_mask(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + mask_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.Explain.Hoc.mask) +} +inline void Explain_Hoc::set_mask(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + mask_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Explain.Hoc.mask) +} +inline void Explain_Hoc::set_mask(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + mask_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Explain.Hoc.mask) +} +inline std::string* Explain_Hoc::_internal_mutable_mask() { + _has_bits_[0] |= 0x00000001u; + return mask_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Explain_Hoc::release_mask() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Explain.Hoc.mask) + if (!_internal_has_mask()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return mask_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Explain_Hoc::set_allocated_mask(std::string* mask) { + if (mask != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + mask_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), mask, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Explain.Hoc.mask) +} + +// repeated .mindspore.irpb.Explain.HocLayer layer = 3; +inline int Explain_Hoc::_internal_layer_size() const { + return layer_.size(); +} +inline int Explain_Hoc::layer_size() const { + return _internal_layer_size(); +} +inline void Explain_Hoc::clear_layer() { + layer_.Clear(); +} +inline ::mindspore::irpb::Explain_HocLayer* Explain_Hoc::mutable_layer(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.Hoc.layer) + return layer_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_HocLayer >* +Explain_Hoc::mutable_layer() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.Hoc.layer) + return &layer_; +} +inline const ::mindspore::irpb::Explain_HocLayer& Explain_Hoc::_internal_layer(int index) const { + return layer_.Get(index); +} +inline const ::mindspore::irpb::Explain_HocLayer& Explain_Hoc::layer(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.Hoc.layer) + return _internal_layer(index); +} +inline ::mindspore::irpb::Explain_HocLayer* Explain_Hoc::_internal_add_layer() { + return layer_.Add(); +} +inline ::mindspore::irpb::Explain_HocLayer* Explain_Hoc::add_layer() { + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.Hoc.layer) + return _internal_add_layer(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_HocLayer >& +Explain_Hoc::layer() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.Hoc.layer) + return layer_; +} + +// ------------------------------------------------------------------- + +// Explain + +// optional int32 sample_id = 1; +inline bool Explain::_internal_has_sample_id() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool Explain::has_sample_id() const { + return _internal_has_sample_id(); +} +inline void Explain::clear_sample_id() { + sample_id_ = 0; + _has_bits_[0] &= ~0x00000010u; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain::_internal_sample_id() const { + return sample_id_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain::sample_id() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.sample_id) + return _internal_sample_id(); +} +inline void Explain::_internal_set_sample_id(::PROTOBUF_NAMESPACE_ID::int32 value) { + _has_bits_[0] |= 0x00000010u; + sample_id_ = value; +} +inline void Explain::set_sample_id(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_sample_id(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.sample_id) +} + +// optional string image_path = 2; +inline bool Explain::_internal_has_image_path() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool Explain::has_image_path() const { + return _internal_has_image_path(); +} +inline void Explain::clear_image_path() { + image_path_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& Explain::image_path() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.image_path) + return _internal_image_path(); +} +inline void Explain::set_image_path(const std::string& value) { + _internal_set_image_path(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.image_path) +} +inline std::string* Explain::mutable_image_path() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.image_path) + return _internal_mutable_image_path(); +} +inline const std::string& Explain::_internal_image_path() const { + return image_path_.Get(); +} +inline void Explain::_internal_set_image_path(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + image_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Explain::set_image_path(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + image_path_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.Explain.image_path) +} +inline void Explain::set_image_path(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + image_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Explain.image_path) +} +inline void Explain::set_image_path(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + image_path_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Explain.image_path) +} +inline std::string* Explain::_internal_mutable_image_path() { + _has_bits_[0] |= 0x00000001u; + return image_path_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Explain::release_image_path() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Explain.image_path) + if (!_internal_has_image_path()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return image_path_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Explain::set_allocated_image_path(std::string* image_path) { + if (image_path != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + image_path_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), image_path, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Explain.image_path) +} + +// repeated int32 ground_truth_label = 3; +inline int Explain::_internal_ground_truth_label_size() const { + return ground_truth_label_.size(); +} +inline int Explain::ground_truth_label_size() const { + return _internal_ground_truth_label_size(); +} +inline void Explain::clear_ground_truth_label() { + ground_truth_label_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain::_internal_ground_truth_label(int index) const { + return ground_truth_label_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Explain::ground_truth_label(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.ground_truth_label) + return _internal_ground_truth_label(index); +} +inline void Explain::set_ground_truth_label(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + ground_truth_label_.Set(index, value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.ground_truth_label) +} +inline void Explain::_internal_add_ground_truth_label(::PROTOBUF_NAMESPACE_ID::int32 value) { + ground_truth_label_.Add(value); +} +inline void Explain::add_ground_truth_label(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_ground_truth_label(value); + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.ground_truth_label) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +Explain::_internal_ground_truth_label() const { + return ground_truth_label_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +Explain::ground_truth_label() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.ground_truth_label) + return _internal_ground_truth_label(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +Explain::_internal_mutable_ground_truth_label() { + return &ground_truth_label_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +Explain::mutable_ground_truth_label() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.ground_truth_label) + return _internal_mutable_ground_truth_label(); +} + +// optional .mindspore.irpb.Explain.Inference inference = 4; +inline bool Explain::_internal_has_inference() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + PROTOBUF_ASSUME(!value || inference_ != nullptr); + return value; +} +inline bool Explain::has_inference() const { + return _internal_has_inference(); +} +inline void Explain::clear_inference() { + if (inference_ != nullptr) inference_->Clear(); + _has_bits_[0] &= ~0x00000004u; +} +inline const ::mindspore::irpb::Explain_Inference& Explain::_internal_inference() const { + const ::mindspore::irpb::Explain_Inference* p = inference_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_Explain_Inference_default_instance_); +} +inline const ::mindspore::irpb::Explain_Inference& Explain::inference() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.inference) + return _internal_inference(); +} +inline void Explain::unsafe_arena_set_allocated_inference( + ::mindspore::irpb::Explain_Inference* inference) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(inference_); + } + inference_ = inference; + if (inference) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.Explain.inference) +} +inline ::mindspore::irpb::Explain_Inference* Explain::release_inference() { + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::Explain_Inference* temp = inference_; + inference_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::Explain_Inference* Explain::unsafe_arena_release_inference() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Explain.inference) + _has_bits_[0] &= ~0x00000004u; + ::mindspore::irpb::Explain_Inference* temp = inference_; + inference_ = nullptr; + return temp; +} +inline ::mindspore::irpb::Explain_Inference* Explain::_internal_mutable_inference() { + _has_bits_[0] |= 0x00000004u; + if (inference_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::Explain_Inference>(GetArena()); + inference_ = p; + } + return inference_; +} +inline ::mindspore::irpb::Explain_Inference* Explain::mutable_inference() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.inference) + return _internal_mutable_inference(); +} +inline void Explain::set_allocated_inference(::mindspore::irpb::Explain_Inference* inference) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete inference_; + } + if (inference) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(inference); + if (message_arena != submessage_arena) { + inference = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, inference, submessage_arena); + } + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + inference_ = inference; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Explain.inference) +} + +// repeated .mindspore.irpb.Explain.Explanation explanation = 5; +inline int Explain::_internal_explanation_size() const { + return explanation_.size(); +} +inline int Explain::explanation_size() const { + return _internal_explanation_size(); +} +inline void Explain::clear_explanation() { + explanation_.Clear(); +} +inline ::mindspore::irpb::Explain_Explanation* Explain::mutable_explanation(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.explanation) + return explanation_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Explanation >* +Explain::mutable_explanation() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.explanation) + return &explanation_; +} +inline const ::mindspore::irpb::Explain_Explanation& Explain::_internal_explanation(int index) const { + return explanation_.Get(index); +} +inline const ::mindspore::irpb::Explain_Explanation& Explain::explanation(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.explanation) + return _internal_explanation(index); +} +inline ::mindspore::irpb::Explain_Explanation* Explain::_internal_add_explanation() { + return explanation_.Add(); +} +inline ::mindspore::irpb::Explain_Explanation* Explain::add_explanation() { + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.explanation) + return _internal_add_explanation(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Explanation >& +Explain::explanation() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.explanation) + return explanation_; +} + +// repeated .mindspore.irpb.Explain.Benchmark benchmark = 6; +inline int Explain::_internal_benchmark_size() const { + return benchmark_.size(); +} +inline int Explain::benchmark_size() const { + return _internal_benchmark_size(); +} +inline void Explain::clear_benchmark() { + benchmark_.Clear(); +} +inline ::mindspore::irpb::Explain_Benchmark* Explain::mutable_benchmark(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.benchmark) + return benchmark_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Benchmark >* +Explain::mutable_benchmark() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.benchmark) + return &benchmark_; +} +inline const ::mindspore::irpb::Explain_Benchmark& Explain::_internal_benchmark(int index) const { + return benchmark_.Get(index); +} +inline const ::mindspore::irpb::Explain_Benchmark& Explain::benchmark(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.benchmark) + return _internal_benchmark(index); +} +inline ::mindspore::irpb::Explain_Benchmark* Explain::_internal_add_benchmark() { + return benchmark_.Add(); +} +inline ::mindspore::irpb::Explain_Benchmark* Explain::add_benchmark() { + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.benchmark) + return _internal_add_benchmark(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Benchmark >& +Explain::benchmark() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.benchmark) + return benchmark_; +} + +// optional .mindspore.irpb.Explain.Metadata metadata = 7; +inline bool Explain::_internal_has_metadata() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + PROTOBUF_ASSUME(!value || metadata_ != nullptr); + return value; +} +inline bool Explain::has_metadata() const { + return _internal_has_metadata(); +} +inline void Explain::clear_metadata() { + if (metadata_ != nullptr) metadata_->Clear(); + _has_bits_[0] &= ~0x00000008u; +} +inline const ::mindspore::irpb::Explain_Metadata& Explain::_internal_metadata() const { + const ::mindspore::irpb::Explain_Metadata* p = metadata_; + return p != nullptr ? *p : *reinterpret_cast( + &::mindspore::irpb::_Explain_Metadata_default_instance_); +} +inline const ::mindspore::irpb::Explain_Metadata& Explain::metadata() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.metadata) + return _internal_metadata(); +} +inline void Explain::unsafe_arena_set_allocated_metadata( + ::mindspore::irpb::Explain_Metadata* metadata) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(metadata_); + } + metadata_ = metadata; + if (metadata) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:mindspore.irpb.Explain.metadata) +} +inline ::mindspore::irpb::Explain_Metadata* Explain::release_metadata() { + _has_bits_[0] &= ~0x00000008u; + ::mindspore::irpb::Explain_Metadata* temp = metadata_; + metadata_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::mindspore::irpb::Explain_Metadata* Explain::unsafe_arena_release_metadata() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Explain.metadata) + _has_bits_[0] &= ~0x00000008u; + ::mindspore::irpb::Explain_Metadata* temp = metadata_; + metadata_ = nullptr; + return temp; +} +inline ::mindspore::irpb::Explain_Metadata* Explain::_internal_mutable_metadata() { + _has_bits_[0] |= 0x00000008u; + if (metadata_ == nullptr) { + auto* p = CreateMaybeMessage<::mindspore::irpb::Explain_Metadata>(GetArena()); + metadata_ = p; + } + return metadata_; +} +inline ::mindspore::irpb::Explain_Metadata* Explain::mutable_metadata() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.metadata) + return _internal_mutable_metadata(); +} +inline void Explain::set_allocated_metadata(::mindspore::irpb::Explain_Metadata* metadata) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete metadata_; + } + if (metadata) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(metadata); + if (message_arena != submessage_arena) { + metadata = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, metadata, submessage_arena); + } + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + metadata_ = metadata; + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Explain.metadata) +} + +// optional string status = 8; +inline bool Explain::_internal_has_status() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool Explain::has_status() const { + return _internal_has_status(); +} +inline void Explain::clear_status() { + status_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& Explain::status() const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.status) + return _internal_status(); +} +inline void Explain::set_status(const std::string& value) { + _internal_set_status(value); + // @@protoc_insertion_point(field_set:mindspore.irpb.Explain.status) +} +inline std::string* Explain::mutable_status() { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.status) + return _internal_mutable_status(); +} +inline const std::string& Explain::_internal_status() const { + return status_.Get(); +} +inline void Explain::_internal_set_status(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + status_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Explain::set_status(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + status_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:mindspore.irpb.Explain.status) +} +inline void Explain::set_status(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + status_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:mindspore.irpb.Explain.status) +} +inline void Explain::set_status(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + status_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:mindspore.irpb.Explain.status) +} +inline std::string* Explain::_internal_mutable_status() { + _has_bits_[0] |= 0x00000002u; + return status_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Explain::release_status() { + // @@protoc_insertion_point(field_release:mindspore.irpb.Explain.status) + if (!_internal_has_status()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return status_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Explain::set_allocated_status(std::string* status) { + if (status != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + status_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), status, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:mindspore.irpb.Explain.status) +} + +// repeated .mindspore.irpb.Explain.Hoc hoc = 9; +inline int Explain::_internal_hoc_size() const { + return hoc_.size(); +} +inline int Explain::hoc_size() const { + return _internal_hoc_size(); +} +inline void Explain::clear_hoc() { + hoc_.Clear(); +} +inline ::mindspore::irpb::Explain_Hoc* Explain::mutable_hoc(int index) { + // @@protoc_insertion_point(field_mutable:mindspore.irpb.Explain.hoc) + return hoc_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Hoc >* +Explain::mutable_hoc() { + // @@protoc_insertion_point(field_mutable_list:mindspore.irpb.Explain.hoc) + return &hoc_; +} +inline const ::mindspore::irpb::Explain_Hoc& Explain::_internal_hoc(int index) const { + return hoc_.Get(index); +} +inline const ::mindspore::irpb::Explain_Hoc& Explain::hoc(int index) const { + // @@protoc_insertion_point(field_get:mindspore.irpb.Explain.hoc) + return _internal_hoc(index); +} +inline ::mindspore::irpb::Explain_Hoc* Explain::_internal_add_hoc() { + return hoc_.Add(); +} +inline ::mindspore::irpb::Explain_Hoc* Explain::add_hoc() { + // @@protoc_insertion_point(field_add:mindspore.irpb.Explain.hoc) + return _internal_add_hoc(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::mindspore::irpb::Explain_Hoc >& +Explain::hoc() const { + // @@protoc_insertion_point(field_list:mindspore.irpb.Explain.hoc) + return hoc_; +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +} // namespace irpb +} // namespace mindspore + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_mindspore_5fsummary_2eproto diff --git a/plugins/mindstudio-insight-plugins/proto/mindspore_summary.proto b/plugins/mindstudio-insight-plugins/proto/mindspore_summary.proto new file mode 100644 index 0000000000000000000000000000000000000000..b9c67279d2c06a12b86932ed4bf8edb3fbe1da22 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/mindspore_summary.proto @@ -0,0 +1,188 @@ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package mindspore.irpb; +option cc_enable_arenas = true; + +// The ANF IR define, include the tensor and graph define +import "mindspore_anf_ir.proto"; + +// Event Protocol buffer, Top define +message Event { + // Timestamp + required double wall_time = 1; + + // The step of train. + optional int64 step = 2; + + oneof what { + // An event file was started, with the specified version. + // Now version is "MindSpore.Event:1" + string version = 3; + + // GraphDef. + GraphProto graph_def = 4; + + // Summary data + Summary summary = 5; + + Explain explain = 6; + } +} + +message LossLandscape{ + message Point { + optional TensorProto x = 1; + optional TensorProto y = 2; + optional TensorProto z = 3; + } + + message LossPath { + repeated int32 intervals = 1; // step intervals or epoch intervals + optional Point points = 2; + } + + message Metadata { + optional string decomposition = 1; + optional string unit = 2; // step or epoch + optional int32 step_per_epoch = 3; + } + + optional Point landscape = 1; + optional LossPath loss_path = 2; + optional Metadata metadata = 3; // maybe only record by the first value + optional Point convergence_point = 4; + +} + +// A Summary is a set of named values that be produced regularly during training +message Summary { + message Image { + // Dimensions of the image. + required int32 height = 1; + required int32 width = 2; + // Valid colorspace values are: + // 1 - grayscale type + // 2 - grayscale + alpha type + // 3 - RGB type + // 4 - RGBA type + // 5 - DIGITAL_YUV type + // 6 - BGRA type + required int32 colorspace = 3; + // Image data in encoded format. Now only support the RGB. + required bytes encoded_image = 4; + } + + message Histogram { + message bucket{ + // Count number of values fallen in [left, left + width). + // For the right most bucket, range is [left, left + width]. + required double left = 1; + required double width = 2; + required int64 count = 3; + } + + repeated bucket buckets = 1; + optional int64 nan_count = 2; + optional int64 pos_inf_count = 3; + optional int64 neg_inf_count = 4; + + // max, min, sum will not take nan and inf into account. + // If there is no valid value in tensor, max will be nan, min will be nan, sum will be 0. + optional double max = 5; + optional double min = 6; + optional double sum = 7; + + // total number of values, including nan and inf + optional int64 count = 8; + } + + message Value { + // Tag name for the data. + required string tag = 1; + + // Value associated with the tag. + oneof value { + float scalar_value = 3; + Image image = 4; + TensorProto tensor = 8; + Histogram histogram = 9; + LossLandscape loss_landscape = 10; + } + } + + // Set of values for the summary. + repeated Value value = 1; +} + +message Explain { + message Inference{ + repeated float ground_truth_prob = 1; + repeated int32 predicted_label = 2; + repeated float predicted_prob = 3; + repeated float ground_truth_prob_sd = 4; + repeated float ground_truth_prob_itl95_low = 5; + repeated float ground_truth_prob_itl95_hi = 6; + repeated float predicted_prob_sd = 7; + repeated float predicted_prob_itl95_low = 8; + repeated float predicted_prob_itl95_hi = 9; + } + + message Explanation{ + optional string explain_method = 1; + optional int32 label = 2; + optional string heatmap_path = 3; + } + + message Benchmark{ + optional string benchmark_method = 1; + optional string explain_method = 2; + optional float total_score = 3; + repeated float label_score = 4; + } + + message Metadata{ + repeated string label = 1; + repeated string explain_method = 2; + repeated string benchmark_method = 3; + } + + message HocLayer { + optional float prob = 1; + repeated int32 box = 2; // List of repeated x, y, w, h + } + + message Hoc { + optional int32 label = 1; + optional string mask = 2; + repeated HocLayer layer = 3; + } + + optional int32 sample_id = 1; + optional string image_path = 2; // The Metadata and image path must have one fill in + repeated int32 ground_truth_label = 3; + + optional Inference inference = 4; + repeated Explanation explanation = 5; + repeated Benchmark benchmark = 6; + + optional Metadata metadata = 7; + optional string status = 8; // enum value: run, end + + repeated Hoc hoc = 9; // hierarchical occlusion counterfactual +} \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/proto/resource_handle.pb.cc b/plugins/mindstudio-insight-plugins/proto/resource_handle.pb.cc new file mode 100644 index 0000000000000000000000000000000000000000..f1fe6e682f74f62c3b608e4316fd5572909e5c51 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/resource_handle.pb.cc @@ -0,0 +1,776 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: resource_handle.proto + +#include "resource_handle.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +extern PROTOBUF_INTERNAL_EXPORT_resource_5fhandle_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_ResourceHandleProto_DtypeAndShape_resource_5fhandle_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_tensor_5fshape_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TensorShapeProto_tensor_5fshape_2eproto; +namespace tensorboard { +class ResourceHandleProto_DtypeAndShapeDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _ResourceHandleProto_DtypeAndShape_default_instance_; +class ResourceHandleProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _ResourceHandleProto_default_instance_; +} // namespace tensorboard +static void InitDefaultsscc_info_ResourceHandleProto_resource_5fhandle_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_ResourceHandleProto_default_instance_; + new (ptr) ::tensorboard::ResourceHandleProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::ResourceHandleProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_ResourceHandleProto_resource_5fhandle_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_ResourceHandleProto_resource_5fhandle_2eproto}, { + &scc_info_ResourceHandleProto_DtypeAndShape_resource_5fhandle_2eproto.base,}}; + +static void InitDefaultsscc_info_ResourceHandleProto_DtypeAndShape_resource_5fhandle_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_ResourceHandleProto_DtypeAndShape_default_instance_; + new (ptr) ::tensorboard::ResourceHandleProto_DtypeAndShape(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::ResourceHandleProto_DtypeAndShape::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_ResourceHandleProto_DtypeAndShape_resource_5fhandle_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_ResourceHandleProto_DtypeAndShape_resource_5fhandle_2eproto}, { + &scc_info_TensorShapeProto_tensor_5fshape_2eproto.base,}}; + +static ::PROTOBUF_NAMESPACE_ID::Metadata file_level_metadata_resource_5fhandle_2eproto[2]; +static constexpr ::PROTOBUF_NAMESPACE_ID::EnumDescriptor const** file_level_enum_descriptors_resource_5fhandle_2eproto = nullptr; +static constexpr ::PROTOBUF_NAMESPACE_ID::ServiceDescriptor const** file_level_service_descriptors_resource_5fhandle_2eproto = nullptr; + +const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_resource_5fhandle_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::ResourceHandleProto_DtypeAndShape, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::ResourceHandleProto_DtypeAndShape, dtype_), + PROTOBUF_FIELD_OFFSET(::tensorboard::ResourceHandleProto_DtypeAndShape, shape_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::ResourceHandleProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::ResourceHandleProto, device_), + PROTOBUF_FIELD_OFFSET(::tensorboard::ResourceHandleProto, container_), + PROTOBUF_FIELD_OFFSET(::tensorboard::ResourceHandleProto, name_), + PROTOBUF_FIELD_OFFSET(::tensorboard::ResourceHandleProto, hash_code_), + PROTOBUF_FIELD_OFFSET(::tensorboard::ResourceHandleProto, maybe_type_name_), + PROTOBUF_FIELD_OFFSET(::tensorboard::ResourceHandleProto, dtypes_and_shapes_), +}; +static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, -1, sizeof(::tensorboard::ResourceHandleProto_DtypeAndShape)}, + { 7, -1, sizeof(::tensorboard::ResourceHandleProto)}, +}; + +static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = { + reinterpret_cast(&::tensorboard::_ResourceHandleProto_DtypeAndShape_default_instance_), + reinterpret_cast(&::tensorboard::_ResourceHandleProto_default_instance_), +}; + +const char descriptor_table_protodef_resource_5fhandle_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = + "\n\025resource_handle.proto\022\013tensorboard\032\022te" + "nsor_shape.proto\032\013types.proto\"\250\002\n\023Resour" + "ceHandleProto\022\016\n\006device\030\001 \001(\t\022\021\n\tcontain" + "er\030\002 \001(\t\022\014\n\004name\030\003 \001(\t\022\021\n\thash_code\030\004 \001(" + "\004\022\027\n\017maybe_type_name\030\005 \001(\t\022I\n\021dtypes_and" + "_shapes\030\006 \003(\0132..tensorboard.ResourceHand" + "leProto.DtypeAndShape\032c\n\rDtypeAndShape\022$" + "\n\005dtype\030\001 \001(\0162\025.tensorboard.DataType\022,\n\005" + "shape\030\002 \001(\0132\035.tensorboard.TensorShapePro" + "toJ\004\010\007\020\010B\207\001\n\030org.tensorflow.frameworkB\016R" + "esourceHandleP\001ZVgithub.com/tensorflow/t" + "ensorflow/tensorflow/go/core/framework/r" + "esource_handle_go_proto\370\001\001b\006proto3" + ; +static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_resource_5fhandle_2eproto_deps[2] = { + &::descriptor_table_tensor_5fshape_2eproto, + &::descriptor_table_types_2eproto, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_resource_5fhandle_2eproto_sccs[2] = { + &scc_info_ResourceHandleProto_resource_5fhandle_2eproto.base, + &scc_info_ResourceHandleProto_DtypeAndShape_resource_5fhandle_2eproto.base, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_resource_5fhandle_2eproto_once; +const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_resource_5fhandle_2eproto = { + false, false, descriptor_table_protodef_resource_5fhandle_2eproto, "resource_handle.proto", 514, + &descriptor_table_resource_5fhandle_2eproto_once, descriptor_table_resource_5fhandle_2eproto_sccs, descriptor_table_resource_5fhandle_2eproto_deps, 2, 2, + schemas, file_default_instances, TableStruct_resource_5fhandle_2eproto::offsets, + file_level_metadata_resource_5fhandle_2eproto, 2, file_level_enum_descriptors_resource_5fhandle_2eproto, file_level_service_descriptors_resource_5fhandle_2eproto, +}; + +// Force running AddDescriptors() at dynamic initialization time. +static bool dynamic_init_dummy_resource_5fhandle_2eproto = (static_cast(::PROTOBUF_NAMESPACE_ID::internal::AddDescriptors(&descriptor_table_resource_5fhandle_2eproto)), true); +namespace tensorboard { + +// =================================================================== + +void ResourceHandleProto_DtypeAndShape::InitAsDefaultInstance() { + ::tensorboard::_ResourceHandleProto_DtypeAndShape_default_instance_._instance.get_mutable()->shape_ = const_cast< ::tensorboard::TensorShapeProto*>( + ::tensorboard::TensorShapeProto::internal_default_instance()); +} +class ResourceHandleProto_DtypeAndShape::_Internal { + public: + static const ::tensorboard::TensorShapeProto& shape(const ResourceHandleProto_DtypeAndShape* msg); +}; + +const ::tensorboard::TensorShapeProto& +ResourceHandleProto_DtypeAndShape::_Internal::shape(const ResourceHandleProto_DtypeAndShape* msg) { + return *msg->shape_; +} +void ResourceHandleProto_DtypeAndShape::clear_shape() { + if (GetArena() == nullptr && shape_ != nullptr) { + delete shape_; + } + shape_ = nullptr; +} +ResourceHandleProto_DtypeAndShape::ResourceHandleProto_DtypeAndShape(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.ResourceHandleProto.DtypeAndShape) +} +ResourceHandleProto_DtypeAndShape::ResourceHandleProto_DtypeAndShape(const ResourceHandleProto_DtypeAndShape& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + if (from._internal_has_shape()) { + shape_ = new ::tensorboard::TensorShapeProto(*from.shape_); + } else { + shape_ = nullptr; + } + dtype_ = from.dtype_; + // @@protoc_insertion_point(copy_constructor:tensorboard.ResourceHandleProto.DtypeAndShape) +} + +void ResourceHandleProto_DtypeAndShape::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_ResourceHandleProto_DtypeAndShape_resource_5fhandle_2eproto.base); + ::memset(&shape_, 0, static_cast( + reinterpret_cast(&dtype_) - + reinterpret_cast(&shape_)) + sizeof(dtype_)); +} + +ResourceHandleProto_DtypeAndShape::~ResourceHandleProto_DtypeAndShape() { + // @@protoc_insertion_point(destructor:tensorboard.ResourceHandleProto.DtypeAndShape) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void ResourceHandleProto_DtypeAndShape::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (this != internal_default_instance()) delete shape_; +} + +void ResourceHandleProto_DtypeAndShape::ArenaDtor(void* object) { + ResourceHandleProto_DtypeAndShape* _this = reinterpret_cast< ResourceHandleProto_DtypeAndShape* >(object); + (void)_this; +} +void ResourceHandleProto_DtypeAndShape::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void ResourceHandleProto_DtypeAndShape::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const ResourceHandleProto_DtypeAndShape& ResourceHandleProto_DtypeAndShape::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_ResourceHandleProto_DtypeAndShape_resource_5fhandle_2eproto.base); + return *internal_default_instance(); +} + + +void ResourceHandleProto_DtypeAndShape::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.ResourceHandleProto.DtypeAndShape) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + if (GetArena() == nullptr && shape_ != nullptr) { + delete shape_; + } + shape_ = nullptr; + dtype_ = 0; + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* ResourceHandleProto_DtypeAndShape::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // .tensorboard.DataType dtype = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + _internal_set_dtype(static_cast<::tensorboard::DataType>(val)); + } else goto handle_unusual; + continue; + // .tensorboard.TensorShapeProto shape = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_shape(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* ResourceHandleProto_DtypeAndShape::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.ResourceHandleProto.DtypeAndShape) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // .tensorboard.DataType dtype = 1; + if (this->dtype() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_dtype(), target); + } + + // .tensorboard.TensorShapeProto shape = 2; + if (this->has_shape()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::shape(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.ResourceHandleProto.DtypeAndShape) + return target; +} + +size_t ResourceHandleProto_DtypeAndShape::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.ResourceHandleProto.DtypeAndShape) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // .tensorboard.TensorShapeProto shape = 2; + if (this->has_shape()) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *shape_); + } + + // .tensorboard.DataType dtype = 1; + if (this->dtype() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_dtype()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void ResourceHandleProto_DtypeAndShape::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.ResourceHandleProto.DtypeAndShape) + GOOGLE_DCHECK_NE(&from, this); + const ResourceHandleProto_DtypeAndShape* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.ResourceHandleProto.DtypeAndShape) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.ResourceHandleProto.DtypeAndShape) + MergeFrom(*source); + } +} + +void ResourceHandleProto_DtypeAndShape::MergeFrom(const ResourceHandleProto_DtypeAndShape& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.ResourceHandleProto.DtypeAndShape) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.has_shape()) { + _internal_mutable_shape()->::tensorboard::TensorShapeProto::MergeFrom(from._internal_shape()); + } + if (from.dtype() != 0) { + _internal_set_dtype(from._internal_dtype()); + } +} + +void ResourceHandleProto_DtypeAndShape::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.ResourceHandleProto.DtypeAndShape) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ResourceHandleProto_DtypeAndShape::CopyFrom(const ResourceHandleProto_DtypeAndShape& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.ResourceHandleProto.DtypeAndShape) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ResourceHandleProto_DtypeAndShape::IsInitialized() const { + return true; +} + +void ResourceHandleProto_DtypeAndShape::InternalSwap(ResourceHandleProto_DtypeAndShape* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(ResourceHandleProto_DtypeAndShape, dtype_) + + sizeof(ResourceHandleProto_DtypeAndShape::dtype_) + - PROTOBUF_FIELD_OFFSET(ResourceHandleProto_DtypeAndShape, shape_)>( + reinterpret_cast(&shape_), + reinterpret_cast(&other->shape_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata ResourceHandleProto_DtypeAndShape::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void ResourceHandleProto::InitAsDefaultInstance() { +} +class ResourceHandleProto::_Internal { + public: +}; + +ResourceHandleProto::ResourceHandleProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + dtypes_and_shapes_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.ResourceHandleProto) +} +ResourceHandleProto::ResourceHandleProto(const ResourceHandleProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + dtypes_and_shapes_(from.dtypes_and_shapes_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + device_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_device().empty()) { + device_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_device(), + GetArena()); + } + container_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_container().empty()) { + container_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_container(), + GetArena()); + } + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_name().empty()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + maybe_type_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_maybe_type_name().empty()) { + maybe_type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_maybe_type_name(), + GetArena()); + } + hash_code_ = from.hash_code_; + // @@protoc_insertion_point(copy_constructor:tensorboard.ResourceHandleProto) +} + +void ResourceHandleProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_ResourceHandleProto_resource_5fhandle_2eproto.base); + device_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + container_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + maybe_type_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + hash_code_ = PROTOBUF_ULONGLONG(0); +} + +ResourceHandleProto::~ResourceHandleProto() { + // @@protoc_insertion_point(destructor:tensorboard.ResourceHandleProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void ResourceHandleProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + device_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + container_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + maybe_type_name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void ResourceHandleProto::ArenaDtor(void* object) { + ResourceHandleProto* _this = reinterpret_cast< ResourceHandleProto* >(object); + (void)_this; +} +void ResourceHandleProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void ResourceHandleProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const ResourceHandleProto& ResourceHandleProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_ResourceHandleProto_resource_5fhandle_2eproto.base); + return *internal_default_instance(); +} + + +void ResourceHandleProto::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.ResourceHandleProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + dtypes_and_shapes_.Clear(); + device_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + container_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + maybe_type_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + hash_code_ = PROTOBUF_ULONGLONG(0); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* ResourceHandleProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // string device = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_device(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.ResourceHandleProto.device")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string container = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_container(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.ResourceHandleProto.container")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string name = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.ResourceHandleProto.name")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // uint64 hash_code = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 32)) { + hash_code_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string maybe_type_name = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + auto str = _internal_mutable_maybe_type_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.ResourceHandleProto.maybe_type_name")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .tensorboard.ResourceHandleProto.DtypeAndShape dtypes_and_shapes = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_dtypes_and_shapes(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<50>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* ResourceHandleProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.ResourceHandleProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // string device = 1; + if (this->device().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_device().data(), static_cast(this->_internal_device().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.ResourceHandleProto.device"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_device(), target); + } + + // string container = 2; + if (this->container().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_container().data(), static_cast(this->_internal_container().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.ResourceHandleProto.container"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_container(), target); + } + + // string name = 3; + if (this->name().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.ResourceHandleProto.name"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_name(), target); + } + + // uint64 hash_code = 4; + if (this->hash_code() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteUInt64ToArray(4, this->_internal_hash_code(), target); + } + + // string maybe_type_name = 5; + if (this->maybe_type_name().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_maybe_type_name().data(), static_cast(this->_internal_maybe_type_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.ResourceHandleProto.maybe_type_name"); + target = stream->WriteStringMaybeAliased( + 5, this->_internal_maybe_type_name(), target); + } + + // repeated .tensorboard.ResourceHandleProto.DtypeAndShape dtypes_and_shapes = 6; + for (unsigned int i = 0, + n = static_cast(this->_internal_dtypes_and_shapes_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(6, this->_internal_dtypes_and_shapes(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.ResourceHandleProto) + return target; +} + +size_t ResourceHandleProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.ResourceHandleProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .tensorboard.ResourceHandleProto.DtypeAndShape dtypes_and_shapes = 6; + total_size += 1UL * this->_internal_dtypes_and_shapes_size(); + for (const auto& msg : this->dtypes_and_shapes_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // string device = 1; + if (this->device().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_device()); + } + + // string container = 2; + if (this->container().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_container()); + } + + // string name = 3; + if (this->name().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // string maybe_type_name = 5; + if (this->maybe_type_name().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_maybe_type_name()); + } + + // uint64 hash_code = 4; + if (this->hash_code() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::UInt64Size( + this->_internal_hash_code()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void ResourceHandleProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.ResourceHandleProto) + GOOGLE_DCHECK_NE(&from, this); + const ResourceHandleProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.ResourceHandleProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.ResourceHandleProto) + MergeFrom(*source); + } +} + +void ResourceHandleProto::MergeFrom(const ResourceHandleProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.ResourceHandleProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + dtypes_and_shapes_.MergeFrom(from.dtypes_and_shapes_); + if (from.device().size() > 0) { + _internal_set_device(from._internal_device()); + } + if (from.container().size() > 0) { + _internal_set_container(from._internal_container()); + } + if (from.name().size() > 0) { + _internal_set_name(from._internal_name()); + } + if (from.maybe_type_name().size() > 0) { + _internal_set_maybe_type_name(from._internal_maybe_type_name()); + } + if (from.hash_code() != 0) { + _internal_set_hash_code(from._internal_hash_code()); + } +} + +void ResourceHandleProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.ResourceHandleProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ResourceHandleProto::CopyFrom(const ResourceHandleProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.ResourceHandleProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ResourceHandleProto::IsInitialized() const { + return true; +} + +void ResourceHandleProto::InternalSwap(ResourceHandleProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + dtypes_and_shapes_.InternalSwap(&other->dtypes_and_shapes_); + device_.Swap(&other->device_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + container_.Swap(&other->container_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + maybe_type_name_.Swap(&other->maybe_type_name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(hash_code_, other->hash_code_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata ResourceHandleProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_NOINLINE ::tensorboard::ResourceHandleProto_DtypeAndShape* Arena::CreateMaybeMessage< ::tensorboard::ResourceHandleProto_DtypeAndShape >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::ResourceHandleProto_DtypeAndShape >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::ResourceHandleProto* Arena::CreateMaybeMessage< ::tensorboard::ResourceHandleProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::ResourceHandleProto >(arena); +} +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) +#include diff --git a/plugins/mindstudio-insight-plugins/proto/resource_handle.pb.h b/plugins/mindstudio-insight-plugins/proto/resource_handle.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..f4407a57f9e0822414239984bc78ad268d60c014 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/resource_handle.pb.h @@ -0,0 +1,893 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: resource_handle.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_resource_5fhandle_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_resource_5fhandle_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include "tensor_shape.pb.h" +#include "types.pb.h" +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_resource_5fhandle_2eproto +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct TableStruct_resource_5fhandle_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[2] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_resource_5fhandle_2eproto; +namespace tensorboard { +class ResourceHandleProto; +class ResourceHandleProtoDefaultTypeInternal; +extern ResourceHandleProtoDefaultTypeInternal _ResourceHandleProto_default_instance_; +class ResourceHandleProto_DtypeAndShape; +class ResourceHandleProto_DtypeAndShapeDefaultTypeInternal; +extern ResourceHandleProto_DtypeAndShapeDefaultTypeInternal _ResourceHandleProto_DtypeAndShape_default_instance_; +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> ::tensorboard::ResourceHandleProto* Arena::CreateMaybeMessage<::tensorboard::ResourceHandleProto>(Arena*); +template<> ::tensorboard::ResourceHandleProto_DtypeAndShape* Arena::CreateMaybeMessage<::tensorboard::ResourceHandleProto_DtypeAndShape>(Arena*); +PROTOBUF_NAMESPACE_CLOSE +namespace tensorboard { + +// =================================================================== + +class ResourceHandleProto_DtypeAndShape PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.ResourceHandleProto.DtypeAndShape) */ { + public: + inline ResourceHandleProto_DtypeAndShape() : ResourceHandleProto_DtypeAndShape(nullptr) {} + virtual ~ResourceHandleProto_DtypeAndShape(); + + ResourceHandleProto_DtypeAndShape(const ResourceHandleProto_DtypeAndShape& from); + ResourceHandleProto_DtypeAndShape(ResourceHandleProto_DtypeAndShape&& from) noexcept + : ResourceHandleProto_DtypeAndShape() { + *this = ::std::move(from); + } + + inline ResourceHandleProto_DtypeAndShape& operator=(const ResourceHandleProto_DtypeAndShape& from) { + CopyFrom(from); + return *this; + } + inline ResourceHandleProto_DtypeAndShape& operator=(ResourceHandleProto_DtypeAndShape&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ResourceHandleProto_DtypeAndShape& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ResourceHandleProto_DtypeAndShape* internal_default_instance() { + return reinterpret_cast( + &_ResourceHandleProto_DtypeAndShape_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(ResourceHandleProto_DtypeAndShape& a, ResourceHandleProto_DtypeAndShape& b) { + a.Swap(&b); + } + inline void Swap(ResourceHandleProto_DtypeAndShape* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ResourceHandleProto_DtypeAndShape* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ResourceHandleProto_DtypeAndShape* New() const final { + return CreateMaybeMessage(nullptr); + } + + ResourceHandleProto_DtypeAndShape* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ResourceHandleProto_DtypeAndShape& from); + void MergeFrom(const ResourceHandleProto_DtypeAndShape& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ResourceHandleProto_DtypeAndShape* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.ResourceHandleProto.DtypeAndShape"; + } + protected: + explicit ResourceHandleProto_DtypeAndShape(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_resource_5fhandle_2eproto); + return ::descriptor_table_resource_5fhandle_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kShapeFieldNumber = 2, + kDtypeFieldNumber = 1, + }; + // .tensorboard.TensorShapeProto shape = 2; + bool has_shape() const; + private: + bool _internal_has_shape() const; + public: + void clear_shape(); + const ::tensorboard::TensorShapeProto& shape() const; + ::tensorboard::TensorShapeProto* release_shape(); + ::tensorboard::TensorShapeProto* mutable_shape(); + void set_allocated_shape(::tensorboard::TensorShapeProto* shape); + private: + const ::tensorboard::TensorShapeProto& _internal_shape() const; + ::tensorboard::TensorShapeProto* _internal_mutable_shape(); + public: + void unsafe_arena_set_allocated_shape( + ::tensorboard::TensorShapeProto* shape); + ::tensorboard::TensorShapeProto* unsafe_arena_release_shape(); + + // .tensorboard.DataType dtype = 1; + void clear_dtype(); + ::tensorboard::DataType dtype() const; + void set_dtype(::tensorboard::DataType value); + private: + ::tensorboard::DataType _internal_dtype() const; + void _internal_set_dtype(::tensorboard::DataType value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.ResourceHandleProto.DtypeAndShape) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::tensorboard::TensorShapeProto* shape_; + int dtype_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_resource_5fhandle_2eproto; +}; +// ------------------------------------------------------------------- + +class ResourceHandleProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.ResourceHandleProto) */ { + public: + inline ResourceHandleProto() : ResourceHandleProto(nullptr) {} + virtual ~ResourceHandleProto(); + + ResourceHandleProto(const ResourceHandleProto& from); + ResourceHandleProto(ResourceHandleProto&& from) noexcept + : ResourceHandleProto() { + *this = ::std::move(from); + } + + inline ResourceHandleProto& operator=(const ResourceHandleProto& from) { + CopyFrom(from); + return *this; + } + inline ResourceHandleProto& operator=(ResourceHandleProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ResourceHandleProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ResourceHandleProto* internal_default_instance() { + return reinterpret_cast( + &_ResourceHandleProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(ResourceHandleProto& a, ResourceHandleProto& b) { + a.Swap(&b); + } + inline void Swap(ResourceHandleProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ResourceHandleProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ResourceHandleProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + ResourceHandleProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ResourceHandleProto& from); + void MergeFrom(const ResourceHandleProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ResourceHandleProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.ResourceHandleProto"; + } + protected: + explicit ResourceHandleProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_resource_5fhandle_2eproto); + return ::descriptor_table_resource_5fhandle_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef ResourceHandleProto_DtypeAndShape DtypeAndShape; + + // accessors ------------------------------------------------------- + + enum : int { + kDtypesAndShapesFieldNumber = 6, + kDeviceFieldNumber = 1, + kContainerFieldNumber = 2, + kNameFieldNumber = 3, + kMaybeTypeNameFieldNumber = 5, + kHashCodeFieldNumber = 4, + }; + // repeated .tensorboard.ResourceHandleProto.DtypeAndShape dtypes_and_shapes = 6; + int dtypes_and_shapes_size() const; + private: + int _internal_dtypes_and_shapes_size() const; + public: + void clear_dtypes_and_shapes(); + ::tensorboard::ResourceHandleProto_DtypeAndShape* mutable_dtypes_and_shapes(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::ResourceHandleProto_DtypeAndShape >* + mutable_dtypes_and_shapes(); + private: + const ::tensorboard::ResourceHandleProto_DtypeAndShape& _internal_dtypes_and_shapes(int index) const; + ::tensorboard::ResourceHandleProto_DtypeAndShape* _internal_add_dtypes_and_shapes(); + public: + const ::tensorboard::ResourceHandleProto_DtypeAndShape& dtypes_and_shapes(int index) const; + ::tensorboard::ResourceHandleProto_DtypeAndShape* add_dtypes_and_shapes(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::ResourceHandleProto_DtypeAndShape >& + dtypes_and_shapes() const; + + // string device = 1; + void clear_device(); + const std::string& device() const; + void set_device(const std::string& value); + void set_device(std::string&& value); + void set_device(const char* value); + void set_device(const char* value, size_t size); + std::string* mutable_device(); + std::string* release_device(); + void set_allocated_device(std::string* device); + private: + const std::string& _internal_device() const; + void _internal_set_device(const std::string& value); + std::string* _internal_mutable_device(); + public: + + // string container = 2; + void clear_container(); + const std::string& container() const; + void set_container(const std::string& value); + void set_container(std::string&& value); + void set_container(const char* value); + void set_container(const char* value, size_t size); + std::string* mutable_container(); + std::string* release_container(); + void set_allocated_container(std::string* container); + private: + const std::string& _internal_container() const; + void _internal_set_container(const std::string& value); + std::string* _internal_mutable_container(); + public: + + // string name = 3; + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // string maybe_type_name = 5; + void clear_maybe_type_name(); + const std::string& maybe_type_name() const; + void set_maybe_type_name(const std::string& value); + void set_maybe_type_name(std::string&& value); + void set_maybe_type_name(const char* value); + void set_maybe_type_name(const char* value, size_t size); + std::string* mutable_maybe_type_name(); + std::string* release_maybe_type_name(); + void set_allocated_maybe_type_name(std::string* maybe_type_name); + private: + const std::string& _internal_maybe_type_name() const; + void _internal_set_maybe_type_name(const std::string& value); + std::string* _internal_mutable_maybe_type_name(); + public: + + // uint64 hash_code = 4; + void clear_hash_code(); + ::PROTOBUF_NAMESPACE_ID::uint64 hash_code() const; + void set_hash_code(::PROTOBUF_NAMESPACE_ID::uint64 value); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_hash_code() const; + void _internal_set_hash_code(::PROTOBUF_NAMESPACE_ID::uint64 value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.ResourceHandleProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::ResourceHandleProto_DtypeAndShape > dtypes_and_shapes_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr device_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr container_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr maybe_type_name_; + ::PROTOBUF_NAMESPACE_ID::uint64 hash_code_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_resource_5fhandle_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// ResourceHandleProto_DtypeAndShape + +// .tensorboard.DataType dtype = 1; +inline void ResourceHandleProto_DtypeAndShape::clear_dtype() { + dtype_ = 0; +} +inline ::tensorboard::DataType ResourceHandleProto_DtypeAndShape::_internal_dtype() const { + return static_cast< ::tensorboard::DataType >(dtype_); +} +inline ::tensorboard::DataType ResourceHandleProto_DtypeAndShape::dtype() const { + // @@protoc_insertion_point(field_get:tensorboard.ResourceHandleProto.DtypeAndShape.dtype) + return _internal_dtype(); +} +inline void ResourceHandleProto_DtypeAndShape::_internal_set_dtype(::tensorboard::DataType value) { + + dtype_ = value; +} +inline void ResourceHandleProto_DtypeAndShape::set_dtype(::tensorboard::DataType value) { + _internal_set_dtype(value); + // @@protoc_insertion_point(field_set:tensorboard.ResourceHandleProto.DtypeAndShape.dtype) +} + +// .tensorboard.TensorShapeProto shape = 2; +inline bool ResourceHandleProto_DtypeAndShape::_internal_has_shape() const { + return this != internal_default_instance() && shape_ != nullptr; +} +inline bool ResourceHandleProto_DtypeAndShape::has_shape() const { + return _internal_has_shape(); +} +inline const ::tensorboard::TensorShapeProto& ResourceHandleProto_DtypeAndShape::_internal_shape() const { + const ::tensorboard::TensorShapeProto* p = shape_; + return p != nullptr ? *p : *reinterpret_cast( + &::tensorboard::_TensorShapeProto_default_instance_); +} +inline const ::tensorboard::TensorShapeProto& ResourceHandleProto_DtypeAndShape::shape() const { + // @@protoc_insertion_point(field_get:tensorboard.ResourceHandleProto.DtypeAndShape.shape) + return _internal_shape(); +} +inline void ResourceHandleProto_DtypeAndShape::unsafe_arena_set_allocated_shape( + ::tensorboard::TensorShapeProto* shape) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(shape_); + } + shape_ = shape; + if (shape) { + + } else { + + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.ResourceHandleProto.DtypeAndShape.shape) +} +inline ::tensorboard::TensorShapeProto* ResourceHandleProto_DtypeAndShape::release_shape() { + + ::tensorboard::TensorShapeProto* temp = shape_; + shape_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::tensorboard::TensorShapeProto* ResourceHandleProto_DtypeAndShape::unsafe_arena_release_shape() { + // @@protoc_insertion_point(field_release:tensorboard.ResourceHandleProto.DtypeAndShape.shape) + + ::tensorboard::TensorShapeProto* temp = shape_; + shape_ = nullptr; + return temp; +} +inline ::tensorboard::TensorShapeProto* ResourceHandleProto_DtypeAndShape::_internal_mutable_shape() { + + if (shape_ == nullptr) { + auto* p = CreateMaybeMessage<::tensorboard::TensorShapeProto>(GetArena()); + shape_ = p; + } + return shape_; +} +inline ::tensorboard::TensorShapeProto* ResourceHandleProto_DtypeAndShape::mutable_shape() { + // @@protoc_insertion_point(field_mutable:tensorboard.ResourceHandleProto.DtypeAndShape.shape) + return _internal_mutable_shape(); +} +inline void ResourceHandleProto_DtypeAndShape::set_allocated_shape(::tensorboard::TensorShapeProto* shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete reinterpret_cast< ::PROTOBUF_NAMESPACE_ID::MessageLite*>(shape_); + } + if (shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(shape)->GetArena(); + if (message_arena != submessage_arena) { + shape = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, shape, submessage_arena); + } + + } else { + + } + shape_ = shape; + // @@protoc_insertion_point(field_set_allocated:tensorboard.ResourceHandleProto.DtypeAndShape.shape) +} + +// ------------------------------------------------------------------- + +// ResourceHandleProto + +// string device = 1; +inline void ResourceHandleProto::clear_device() { + device_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& ResourceHandleProto::device() const { + // @@protoc_insertion_point(field_get:tensorboard.ResourceHandleProto.device) + return _internal_device(); +} +inline void ResourceHandleProto::set_device(const std::string& value) { + _internal_set_device(value); + // @@protoc_insertion_point(field_set:tensorboard.ResourceHandleProto.device) +} +inline std::string* ResourceHandleProto::mutable_device() { + // @@protoc_insertion_point(field_mutable:tensorboard.ResourceHandleProto.device) + return _internal_mutable_device(); +} +inline const std::string& ResourceHandleProto::_internal_device() const { + return device_.Get(); +} +inline void ResourceHandleProto::_internal_set_device(const std::string& value) { + + device_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ResourceHandleProto::set_device(std::string&& value) { + + device_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.ResourceHandleProto.device) +} +inline void ResourceHandleProto::set_device(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + device_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.ResourceHandleProto.device) +} +inline void ResourceHandleProto::set_device(const char* value, + size_t size) { + + device_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.ResourceHandleProto.device) +} +inline std::string* ResourceHandleProto::_internal_mutable_device() { + + return device_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ResourceHandleProto::release_device() { + // @@protoc_insertion_point(field_release:tensorboard.ResourceHandleProto.device) + return device_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ResourceHandleProto::set_allocated_device(std::string* device) { + if (device != nullptr) { + + } else { + + } + device_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), device, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.ResourceHandleProto.device) +} + +// string container = 2; +inline void ResourceHandleProto::clear_container() { + container_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& ResourceHandleProto::container() const { + // @@protoc_insertion_point(field_get:tensorboard.ResourceHandleProto.container) + return _internal_container(); +} +inline void ResourceHandleProto::set_container(const std::string& value) { + _internal_set_container(value); + // @@protoc_insertion_point(field_set:tensorboard.ResourceHandleProto.container) +} +inline std::string* ResourceHandleProto::mutable_container() { + // @@protoc_insertion_point(field_mutable:tensorboard.ResourceHandleProto.container) + return _internal_mutable_container(); +} +inline const std::string& ResourceHandleProto::_internal_container() const { + return container_.Get(); +} +inline void ResourceHandleProto::_internal_set_container(const std::string& value) { + + container_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ResourceHandleProto::set_container(std::string&& value) { + + container_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.ResourceHandleProto.container) +} +inline void ResourceHandleProto::set_container(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + container_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.ResourceHandleProto.container) +} +inline void ResourceHandleProto::set_container(const char* value, + size_t size) { + + container_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.ResourceHandleProto.container) +} +inline std::string* ResourceHandleProto::_internal_mutable_container() { + + return container_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ResourceHandleProto::release_container() { + // @@protoc_insertion_point(field_release:tensorboard.ResourceHandleProto.container) + return container_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ResourceHandleProto::set_allocated_container(std::string* container) { + if (container != nullptr) { + + } else { + + } + container_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), container, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.ResourceHandleProto.container) +} + +// string name = 3; +inline void ResourceHandleProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& ResourceHandleProto::name() const { + // @@protoc_insertion_point(field_get:tensorboard.ResourceHandleProto.name) + return _internal_name(); +} +inline void ResourceHandleProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:tensorboard.ResourceHandleProto.name) +} +inline std::string* ResourceHandleProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:tensorboard.ResourceHandleProto.name) + return _internal_mutable_name(); +} +inline const std::string& ResourceHandleProto::_internal_name() const { + return name_.Get(); +} +inline void ResourceHandleProto::_internal_set_name(const std::string& value) { + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ResourceHandleProto::set_name(std::string&& value) { + + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.ResourceHandleProto.name) +} +inline void ResourceHandleProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.ResourceHandleProto.name) +} +inline void ResourceHandleProto::set_name(const char* value, + size_t size) { + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.ResourceHandleProto.name) +} +inline std::string* ResourceHandleProto::_internal_mutable_name() { + + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ResourceHandleProto::release_name() { + // @@protoc_insertion_point(field_release:tensorboard.ResourceHandleProto.name) + return name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ResourceHandleProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + + } else { + + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.ResourceHandleProto.name) +} + +// uint64 hash_code = 4; +inline void ResourceHandleProto::clear_hash_code() { + hash_code_ = PROTOBUF_ULONGLONG(0); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 ResourceHandleProto::_internal_hash_code() const { + return hash_code_; +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 ResourceHandleProto::hash_code() const { + // @@protoc_insertion_point(field_get:tensorboard.ResourceHandleProto.hash_code) + return _internal_hash_code(); +} +inline void ResourceHandleProto::_internal_set_hash_code(::PROTOBUF_NAMESPACE_ID::uint64 value) { + + hash_code_ = value; +} +inline void ResourceHandleProto::set_hash_code(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_set_hash_code(value); + // @@protoc_insertion_point(field_set:tensorboard.ResourceHandleProto.hash_code) +} + +// string maybe_type_name = 5; +inline void ResourceHandleProto::clear_maybe_type_name() { + maybe_type_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& ResourceHandleProto::maybe_type_name() const { + // @@protoc_insertion_point(field_get:tensorboard.ResourceHandleProto.maybe_type_name) + return _internal_maybe_type_name(); +} +inline void ResourceHandleProto::set_maybe_type_name(const std::string& value) { + _internal_set_maybe_type_name(value); + // @@protoc_insertion_point(field_set:tensorboard.ResourceHandleProto.maybe_type_name) +} +inline std::string* ResourceHandleProto::mutable_maybe_type_name() { + // @@protoc_insertion_point(field_mutable:tensorboard.ResourceHandleProto.maybe_type_name) + return _internal_mutable_maybe_type_name(); +} +inline const std::string& ResourceHandleProto::_internal_maybe_type_name() const { + return maybe_type_name_.Get(); +} +inline void ResourceHandleProto::_internal_set_maybe_type_name(const std::string& value) { + + maybe_type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ResourceHandleProto::set_maybe_type_name(std::string&& value) { + + maybe_type_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.ResourceHandleProto.maybe_type_name) +} +inline void ResourceHandleProto::set_maybe_type_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + maybe_type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.ResourceHandleProto.maybe_type_name) +} +inline void ResourceHandleProto::set_maybe_type_name(const char* value, + size_t size) { + + maybe_type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.ResourceHandleProto.maybe_type_name) +} +inline std::string* ResourceHandleProto::_internal_mutable_maybe_type_name() { + + return maybe_type_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ResourceHandleProto::release_maybe_type_name() { + // @@protoc_insertion_point(field_release:tensorboard.ResourceHandleProto.maybe_type_name) + return maybe_type_name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ResourceHandleProto::set_allocated_maybe_type_name(std::string* maybe_type_name) { + if (maybe_type_name != nullptr) { + + } else { + + } + maybe_type_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), maybe_type_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.ResourceHandleProto.maybe_type_name) +} + +// repeated .tensorboard.ResourceHandleProto.DtypeAndShape dtypes_and_shapes = 6; +inline int ResourceHandleProto::_internal_dtypes_and_shapes_size() const { + return dtypes_and_shapes_.size(); +} +inline int ResourceHandleProto::dtypes_and_shapes_size() const { + return _internal_dtypes_and_shapes_size(); +} +inline void ResourceHandleProto::clear_dtypes_and_shapes() { + dtypes_and_shapes_.Clear(); +} +inline ::tensorboard::ResourceHandleProto_DtypeAndShape* ResourceHandleProto::mutable_dtypes_and_shapes(int index) { + // @@protoc_insertion_point(field_mutable:tensorboard.ResourceHandleProto.dtypes_and_shapes) + return dtypes_and_shapes_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::ResourceHandleProto_DtypeAndShape >* +ResourceHandleProto::mutable_dtypes_and_shapes() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.ResourceHandleProto.dtypes_and_shapes) + return &dtypes_and_shapes_; +} +inline const ::tensorboard::ResourceHandleProto_DtypeAndShape& ResourceHandleProto::_internal_dtypes_and_shapes(int index) const { + return dtypes_and_shapes_.Get(index); +} +inline const ::tensorboard::ResourceHandleProto_DtypeAndShape& ResourceHandleProto::dtypes_and_shapes(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.ResourceHandleProto.dtypes_and_shapes) + return _internal_dtypes_and_shapes(index); +} +inline ::tensorboard::ResourceHandleProto_DtypeAndShape* ResourceHandleProto::_internal_add_dtypes_and_shapes() { + return dtypes_and_shapes_.Add(); +} +inline ::tensorboard::ResourceHandleProto_DtypeAndShape* ResourceHandleProto::add_dtypes_and_shapes() { + // @@protoc_insertion_point(field_add:tensorboard.ResourceHandleProto.dtypes_and_shapes) + return _internal_add_dtypes_and_shapes(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::ResourceHandleProto_DtypeAndShape >& +ResourceHandleProto::dtypes_and_shapes() const { + // @@protoc_insertion_point(field_list:tensorboard.ResourceHandleProto.dtypes_and_shapes) + return dtypes_and_shapes_; +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +} // namespace tensorboard + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_resource_5fhandle_2eproto diff --git a/plugins/mindstudio-insight-plugins/proto/resource_handle.proto b/plugins/mindstudio-insight-plugins/proto/resource_handle.proto new file mode 100644 index 0000000000000000000000000000000000000000..084a11cc6f44a6c7d2996ac1e943b132ed02bb8f --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/resource_handle.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; + +package tensorboard; + +import "tensor_shape.proto"; +import "types.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/resource_handle_go_proto"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; + + // Protocol buffer representing a pair of (data type, tensor shape). + message DtypeAndShape { + // Data type of the tensor. + DataType dtype = 1; + // Shape of the tensor. + TensorShapeProto shape = 2; + } + + // Data types and shapes for the underlying resource. + repeated DtypeAndShape dtypes_and_shapes = 6; + + reserved 7; +} diff --git a/plugins/mindstudio-insight-plugins/proto/summary.pb.cc b/plugins/mindstudio-insight-plugins/proto/summary.pb.cc new file mode 100644 index 0000000000000000000000000000000000000000..544c6b0956ca4a0cb5f1bc14be795660c56d3569 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/summary.pb.cc @@ -0,0 +1,2582 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: summary.proto + +#include "summary.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +extern PROTOBUF_INTERNAL_EXPORT_histogram_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_HistogramProto_histogram_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_summary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Summary_Audio_summary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_summary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Summary_Image_summary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_summary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<5> scc_info_Summary_Value_summary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_summary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_SummaryMetadata_summary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_summary_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_SummaryMetadata_PluginData_summary_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_tensor_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<2> scc_info_TensorProto_tensor_2eproto; +namespace tensorboard { +class SummaryDescriptionDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _SummaryDescription_default_instance_; +class SummaryMetadata_PluginDataDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _SummaryMetadata_PluginData_default_instance_; +class SummaryMetadataDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _SummaryMetadata_default_instance_; +class Summary_ImageDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Summary_Image_default_instance_; +class Summary_AudioDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Summary_Audio_default_instance_; +class Summary_ValueDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; + float simple_value_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr obsolete_old_style_histogram_; + const ::tensorboard::Summary_Image* image_; + const ::tensorboard::HistogramProto* histo_; + const ::tensorboard::Summary_Audio* audio_; + const ::tensorboard::TensorProto* tensor_; +} _Summary_Value_default_instance_; +class SummaryDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _Summary_default_instance_; +} // namespace tensorboard +static void InitDefaultsscc_info_Summary_summary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_Summary_default_instance_; + new (ptr) ::tensorboard::Summary(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::Summary::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_Summary_summary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_Summary_summary_2eproto}, { + &scc_info_Summary_Value_summary_2eproto.base,}}; + +static void InitDefaultsscc_info_Summary_Audio_summary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_Summary_Audio_default_instance_; + new (ptr) ::tensorboard::Summary_Audio(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::Summary_Audio::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Summary_Audio_summary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_Summary_Audio_summary_2eproto}, {}}; + +static void InitDefaultsscc_info_Summary_Image_summary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_Summary_Image_default_instance_; + new (ptr) ::tensorboard::Summary_Image(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::Summary_Image::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_Summary_Image_summary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_Summary_Image_summary_2eproto}, {}}; + +static void InitDefaultsscc_info_Summary_Value_summary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_Summary_Value_default_instance_; + new (ptr) ::tensorboard::Summary_Value(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::Summary_Value::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<5> scc_info_Summary_Value_summary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 5, 0, InitDefaultsscc_info_Summary_Value_summary_2eproto}, { + &scc_info_SummaryMetadata_summary_2eproto.base, + &scc_info_Summary_Image_summary_2eproto.base, + &scc_info_HistogramProto_histogram_2eproto.base, + &scc_info_Summary_Audio_summary_2eproto.base, + &scc_info_TensorProto_tensor_2eproto.base,}}; + +static void InitDefaultsscc_info_SummaryDescription_summary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_SummaryDescription_default_instance_; + new (ptr) ::tensorboard::SummaryDescription(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::SummaryDescription::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_SummaryDescription_summary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_SummaryDescription_summary_2eproto}, {}}; + +static void InitDefaultsscc_info_SummaryMetadata_summary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_SummaryMetadata_default_instance_; + new (ptr) ::tensorboard::SummaryMetadata(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::SummaryMetadata::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_SummaryMetadata_summary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_SummaryMetadata_summary_2eproto}, { + &scc_info_SummaryMetadata_PluginData_summary_2eproto.base,}}; + +static void InitDefaultsscc_info_SummaryMetadata_PluginData_summary_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_SummaryMetadata_PluginData_default_instance_; + new (ptr) ::tensorboard::SummaryMetadata_PluginData(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::SummaryMetadata_PluginData::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_SummaryMetadata_PluginData_summary_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_SummaryMetadata_PluginData_summary_2eproto}, {}}; + +static ::PROTOBUF_NAMESPACE_ID::Metadata file_level_metadata_summary_2eproto[7]; +static const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* file_level_enum_descriptors_summary_2eproto[1]; +static constexpr ::PROTOBUF_NAMESPACE_ID::ServiceDescriptor const** file_level_service_descriptors_summary_2eproto = nullptr; + +const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_summary_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SummaryDescription, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SummaryDescription, type_hint_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SummaryMetadata_PluginData, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SummaryMetadata_PluginData, plugin_name_), + PROTOBUF_FIELD_OFFSET(::tensorboard::SummaryMetadata_PluginData, content_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SummaryMetadata, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SummaryMetadata, plugin_data_), + PROTOBUF_FIELD_OFFSET(::tensorboard::SummaryMetadata, display_name_), + PROTOBUF_FIELD_OFFSET(::tensorboard::SummaryMetadata, summary_description_), + PROTOBUF_FIELD_OFFSET(::tensorboard::SummaryMetadata, data_class_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Image, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Image, height_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Image, width_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Image, colorspace_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Image, encoded_image_string_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Audio, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Audio, sample_rate_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Audio, num_channels_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Audio, length_frames_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Audio, encoded_audio_string_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Audio, content_type_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Value, _internal_metadata_), + ~0u, // no _extensions_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Value, _oneof_case_[0]), + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Value, node_name_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Value, tag_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Value, metadata_), + offsetof(::tensorboard::Summary_ValueDefaultTypeInternal, simple_value_), + offsetof(::tensorboard::Summary_ValueDefaultTypeInternal, obsolete_old_style_histogram_), + offsetof(::tensorboard::Summary_ValueDefaultTypeInternal, image_), + offsetof(::tensorboard::Summary_ValueDefaultTypeInternal, histo_), + offsetof(::tensorboard::Summary_ValueDefaultTypeInternal, audio_), + offsetof(::tensorboard::Summary_ValueDefaultTypeInternal, tensor_), + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary_Value, value_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::Summary, value_), +}; +static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, -1, sizeof(::tensorboard::SummaryDescription)}, + { 6, -1, sizeof(::tensorboard::SummaryMetadata_PluginData)}, + { 13, -1, sizeof(::tensorboard::SummaryMetadata)}, + { 22, -1, sizeof(::tensorboard::Summary_Image)}, + { 31, -1, sizeof(::tensorboard::Summary_Audio)}, + { 41, -1, sizeof(::tensorboard::Summary_Value)}, + { 56, -1, sizeof(::tensorboard::Summary)}, +}; + +static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = { + reinterpret_cast(&::tensorboard::_SummaryDescription_default_instance_), + reinterpret_cast(&::tensorboard::_SummaryMetadata_PluginData_default_instance_), + reinterpret_cast(&::tensorboard::_SummaryMetadata_default_instance_), + reinterpret_cast(&::tensorboard::_Summary_Image_default_instance_), + reinterpret_cast(&::tensorboard::_Summary_Audio_default_instance_), + reinterpret_cast(&::tensorboard::_Summary_Value_default_instance_), + reinterpret_cast(&::tensorboard::_Summary_default_instance_), +}; + +const char descriptor_table_protodef_summary_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = + "\n\rsummary.proto\022\013tensorboard\032\017histogram." + "proto\032\014tensor.proto\"\'\n\022SummaryDescriptio" + "n\022\021\n\ttype_hint\030\001 \001(\t\"\342\001\n\017SummaryMetadata" + "\022<\n\013plugin_data\030\001 \001(\0132\'.tensorboard.Summ" + "aryMetadata.PluginData\022\024\n\014display_name\030\002" + " \001(\t\022\033\n\023summary_description\030\003 \001(\t\022*\n\ndat" + "a_class\030\004 \001(\0162\026.tensorboard.DataClass\0322\n" + "\nPluginData\022\023\n\013plugin_name\030\001 \001(\t\022\017\n\007cont" + "ent\030\002 \001(\014\"\344\004\n\007Summary\022)\n\005value\030\001 \003(\0132\032.t" + "ensorboard.Summary.Value\032X\n\005Image\022\016\n\006hei" + "ght\030\001 \001(\005\022\r\n\005width\030\002 \001(\005\022\022\n\ncolorspace\030\003" + " \001(\005\022\034\n\024encoded_image_string\030\004 \001(\014\032}\n\005Au" + "dio\022\023\n\013sample_rate\030\001 \001(\002\022\024\n\014num_channels" + "\030\002 \001(\003\022\025\n\rlength_frames\030\003 \001(\003\022\034\n\024encoded" + "_audio_string\030\004 \001(\014\022\024\n\014content_type\030\005 \001(" + "\t\032\324\002\n\005Value\022\021\n\tnode_name\030\007 \001(\t\022\013\n\003tag\030\001 " + "\001(\t\022.\n\010metadata\030\t \001(\0132\034.tensorboard.Summ" + "aryMetadata\022\026\n\014simple_value\030\002 \001(\002H\000\022&\n\034o" + "bsolete_old_style_histogram\030\003 \001(\014H\000\022+\n\005i" + "mage\030\004 \001(\0132\032.tensorboard.Summary.ImageH\000" + "\022,\n\005histo\030\005 \001(\0132\033.tensorboard.HistogramP" + "rotoH\000\022+\n\005audio\030\006 \001(\0132\032.tensorboard.Summ" + "ary.AudioH\000\022*\n\006tensor\030\010 \001(\0132\030.tensorboar" + "d.TensorProtoH\000B\007\n\005value*o\n\tDataClass\022\026\n" + "\022DATA_CLASS_UNKNOWN\020\000\022\025\n\021DATA_CLASS_SCAL" + "AR\020\001\022\025\n\021DATA_CLASS_TENSOR\020\002\022\034\n\030DATA_CLAS" + "S_BLOB_SEQUENCE\020\003B~\n\030org.tensorflow.fram" + "eworkB\rSummaryProtosP\001ZNgithub.com/tenso" + "rflow/tensorflow/tensorflow/go/core/fram" + "ework/summary_go_proto\370\001\001P\000b\006proto3" + ; +static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_summary_2eproto_deps[2] = { + &::descriptor_table_histogram_2eproto, + &::descriptor_table_tensor_2eproto, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_summary_2eproto_sccs[7] = { + &scc_info_Summary_summary_2eproto.base, + &scc_info_Summary_Audio_summary_2eproto.base, + &scc_info_Summary_Image_summary_2eproto.base, + &scc_info_Summary_Value_summary_2eproto.base, + &scc_info_SummaryDescription_summary_2eproto.base, + &scc_info_SummaryMetadata_summary_2eproto.base, + &scc_info_SummaryMetadata_PluginData_summary_2eproto.base, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_summary_2eproto_once; +const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_summary_2eproto = { + false, false, descriptor_table_protodef_summary_2eproto, "summary.proto", 1195, + &descriptor_table_summary_2eproto_once, descriptor_table_summary_2eproto_sccs, descriptor_table_summary_2eproto_deps, 7, 2, + schemas, file_default_instances, TableStruct_summary_2eproto::offsets, + file_level_metadata_summary_2eproto, 7, file_level_enum_descriptors_summary_2eproto, file_level_service_descriptors_summary_2eproto, +}; + +// Force running AddDescriptors() at dynamic initialization time. +static bool dynamic_init_dummy_summary_2eproto = (static_cast(::PROTOBUF_NAMESPACE_ID::internal::AddDescriptors(&descriptor_table_summary_2eproto)), true); +namespace tensorboard { +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* DataClass_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_summary_2eproto); + return file_level_enum_descriptors_summary_2eproto[0]; +} +bool DataClass_IsValid(int value) { + switch (value) { + case 0: + case 1: + case 2: + case 3: + return true; + default: + return false; + } +} + + +// =================================================================== + +void SummaryDescription::InitAsDefaultInstance() { +} +class SummaryDescription::_Internal { + public: +}; + +SummaryDescription::SummaryDescription(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.SummaryDescription) +} +SummaryDescription::SummaryDescription(const SummaryDescription& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + type_hint_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_type_hint().empty()) { + type_hint_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_type_hint(), + GetArena()); + } + // @@protoc_insertion_point(copy_constructor:tensorboard.SummaryDescription) +} + +void SummaryDescription::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_SummaryDescription_summary_2eproto.base); + type_hint_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +SummaryDescription::~SummaryDescription() { + // @@protoc_insertion_point(destructor:tensorboard.SummaryDescription) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void SummaryDescription::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + type_hint_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void SummaryDescription::ArenaDtor(void* object) { + SummaryDescription* _this = reinterpret_cast< SummaryDescription* >(object); + (void)_this; +} +void SummaryDescription::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void SummaryDescription::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const SummaryDescription& SummaryDescription::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_SummaryDescription_summary_2eproto.base); + return *internal_default_instance(); +} + + +void SummaryDescription::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.SummaryDescription) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + type_hint_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* SummaryDescription::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // string type_hint = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_type_hint(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.SummaryDescription.type_hint")); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* SummaryDescription::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.SummaryDescription) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // string type_hint = 1; + if (this->type_hint().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_type_hint().data(), static_cast(this->_internal_type_hint().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.SummaryDescription.type_hint"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_type_hint(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.SummaryDescription) + return target; +} + +size_t SummaryDescription::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.SummaryDescription) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // string type_hint = 1; + if (this->type_hint().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_type_hint()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void SummaryDescription::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.SummaryDescription) + GOOGLE_DCHECK_NE(&from, this); + const SummaryDescription* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.SummaryDescription) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.SummaryDescription) + MergeFrom(*source); + } +} + +void SummaryDescription::MergeFrom(const SummaryDescription& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.SummaryDescription) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.type_hint().size() > 0) { + _internal_set_type_hint(from._internal_type_hint()); + } +} + +void SummaryDescription::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.SummaryDescription) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void SummaryDescription::CopyFrom(const SummaryDescription& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.SummaryDescription) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool SummaryDescription::IsInitialized() const { + return true; +} + +void SummaryDescription::InternalSwap(SummaryDescription* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + type_hint_.Swap(&other->type_hint_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata SummaryDescription::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void SummaryMetadata_PluginData::InitAsDefaultInstance() { +} +class SummaryMetadata_PluginData::_Internal { + public: +}; + +SummaryMetadata_PluginData::SummaryMetadata_PluginData(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.SummaryMetadata.PluginData) +} +SummaryMetadata_PluginData::SummaryMetadata_PluginData(const SummaryMetadata_PluginData& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + plugin_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_plugin_name().empty()) { + plugin_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_plugin_name(), + GetArena()); + } + content_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_content().empty()) { + content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_content(), + GetArena()); + } + // @@protoc_insertion_point(copy_constructor:tensorboard.SummaryMetadata.PluginData) +} + +void SummaryMetadata_PluginData::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_SummaryMetadata_PluginData_summary_2eproto.base); + plugin_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + content_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +SummaryMetadata_PluginData::~SummaryMetadata_PluginData() { + // @@protoc_insertion_point(destructor:tensorboard.SummaryMetadata.PluginData) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void SummaryMetadata_PluginData::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + plugin_name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + content_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void SummaryMetadata_PluginData::ArenaDtor(void* object) { + SummaryMetadata_PluginData* _this = reinterpret_cast< SummaryMetadata_PluginData* >(object); + (void)_this; +} +void SummaryMetadata_PluginData::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void SummaryMetadata_PluginData::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const SummaryMetadata_PluginData& SummaryMetadata_PluginData::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_SummaryMetadata_PluginData_summary_2eproto.base); + return *internal_default_instance(); +} + + +void SummaryMetadata_PluginData::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.SummaryMetadata.PluginData) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + plugin_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + content_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* SummaryMetadata_PluginData::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // string plugin_name = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_plugin_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.SummaryMetadata.PluginData.plugin_name")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // bytes content = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_content(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* SummaryMetadata_PluginData::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.SummaryMetadata.PluginData) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // string plugin_name = 1; + if (this->plugin_name().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_plugin_name().data(), static_cast(this->_internal_plugin_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.SummaryMetadata.PluginData.plugin_name"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_plugin_name(), target); + } + + // bytes content = 2; + if (this->content().size() > 0) { + target = stream->WriteBytesMaybeAliased( + 2, this->_internal_content(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.SummaryMetadata.PluginData) + return target; +} + +size_t SummaryMetadata_PluginData::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.SummaryMetadata.PluginData) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // string plugin_name = 1; + if (this->plugin_name().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_plugin_name()); + } + + // bytes content = 2; + if (this->content().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_content()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void SummaryMetadata_PluginData::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.SummaryMetadata.PluginData) + GOOGLE_DCHECK_NE(&from, this); + const SummaryMetadata_PluginData* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.SummaryMetadata.PluginData) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.SummaryMetadata.PluginData) + MergeFrom(*source); + } +} + +void SummaryMetadata_PluginData::MergeFrom(const SummaryMetadata_PluginData& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.SummaryMetadata.PluginData) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.plugin_name().size() > 0) { + _internal_set_plugin_name(from._internal_plugin_name()); + } + if (from.content().size() > 0) { + _internal_set_content(from._internal_content()); + } +} + +void SummaryMetadata_PluginData::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.SummaryMetadata.PluginData) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void SummaryMetadata_PluginData::CopyFrom(const SummaryMetadata_PluginData& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.SummaryMetadata.PluginData) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool SummaryMetadata_PluginData::IsInitialized() const { + return true; +} + +void SummaryMetadata_PluginData::InternalSwap(SummaryMetadata_PluginData* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + plugin_name_.Swap(&other->plugin_name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + content_.Swap(&other->content_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata SummaryMetadata_PluginData::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void SummaryMetadata::InitAsDefaultInstance() { + ::tensorboard::_SummaryMetadata_default_instance_._instance.get_mutable()->plugin_data_ = const_cast< ::tensorboard::SummaryMetadata_PluginData*>( + ::tensorboard::SummaryMetadata_PluginData::internal_default_instance()); +} +class SummaryMetadata::_Internal { + public: + static const ::tensorboard::SummaryMetadata_PluginData& plugin_data(const SummaryMetadata* msg); +}; + +const ::tensorboard::SummaryMetadata_PluginData& +SummaryMetadata::_Internal::plugin_data(const SummaryMetadata* msg) { + return *msg->plugin_data_; +} +SummaryMetadata::SummaryMetadata(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.SummaryMetadata) +} +SummaryMetadata::SummaryMetadata(const SummaryMetadata& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + display_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_display_name().empty()) { + display_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_display_name(), + GetArena()); + } + summary_description_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_summary_description().empty()) { + summary_description_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_summary_description(), + GetArena()); + } + if (from._internal_has_plugin_data()) { + plugin_data_ = new ::tensorboard::SummaryMetadata_PluginData(*from.plugin_data_); + } else { + plugin_data_ = nullptr; + } + data_class_ = from.data_class_; + // @@protoc_insertion_point(copy_constructor:tensorboard.SummaryMetadata) +} + +void SummaryMetadata::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_SummaryMetadata_summary_2eproto.base); + display_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + summary_description_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&plugin_data_, 0, static_cast( + reinterpret_cast(&data_class_) - + reinterpret_cast(&plugin_data_)) + sizeof(data_class_)); +} + +SummaryMetadata::~SummaryMetadata() { + // @@protoc_insertion_point(destructor:tensorboard.SummaryMetadata) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void SummaryMetadata::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + display_name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + summary_description_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete plugin_data_; +} + +void SummaryMetadata::ArenaDtor(void* object) { + SummaryMetadata* _this = reinterpret_cast< SummaryMetadata* >(object); + (void)_this; +} +void SummaryMetadata::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void SummaryMetadata::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const SummaryMetadata& SummaryMetadata::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_SummaryMetadata_summary_2eproto.base); + return *internal_default_instance(); +} + + +void SummaryMetadata::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.SummaryMetadata) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + display_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + summary_description_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + if (GetArena() == nullptr && plugin_data_ != nullptr) { + delete plugin_data_; + } + plugin_data_ = nullptr; + data_class_ = 0; + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* SummaryMetadata::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // .tensorboard.SummaryMetadata.PluginData plugin_data = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr = ctx->ParseMessage(_internal_mutable_plugin_data(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string display_name = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_display_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.SummaryMetadata.display_name")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string summary_description = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_summary_description(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.SummaryMetadata.summary_description")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.DataClass data_class = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 32)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + _internal_set_data_class(static_cast<::tensorboard::DataClass>(val)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* SummaryMetadata::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.SummaryMetadata) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // .tensorboard.SummaryMetadata.PluginData plugin_data = 1; + if (this->has_plugin_data()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 1, _Internal::plugin_data(this), target, stream); + } + + // string display_name = 2; + if (this->display_name().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_display_name().data(), static_cast(this->_internal_display_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.SummaryMetadata.display_name"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_display_name(), target); + } + + // string summary_description = 3; + if (this->summary_description().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_summary_description().data(), static_cast(this->_internal_summary_description().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.SummaryMetadata.summary_description"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_summary_description(), target); + } + + // .tensorboard.DataClass data_class = 4; + if (this->data_class() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 4, this->_internal_data_class(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.SummaryMetadata) + return target; +} + +size_t SummaryMetadata::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.SummaryMetadata) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // string display_name = 2; + if (this->display_name().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_display_name()); + } + + // string summary_description = 3; + if (this->summary_description().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_summary_description()); + } + + // .tensorboard.SummaryMetadata.PluginData plugin_data = 1; + if (this->has_plugin_data()) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *plugin_data_); + } + + // .tensorboard.DataClass data_class = 4; + if (this->data_class() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_data_class()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void SummaryMetadata::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.SummaryMetadata) + GOOGLE_DCHECK_NE(&from, this); + const SummaryMetadata* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.SummaryMetadata) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.SummaryMetadata) + MergeFrom(*source); + } +} + +void SummaryMetadata::MergeFrom(const SummaryMetadata& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.SummaryMetadata) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.display_name().size() > 0) { + _internal_set_display_name(from._internal_display_name()); + } + if (from.summary_description().size() > 0) { + _internal_set_summary_description(from._internal_summary_description()); + } + if (from.has_plugin_data()) { + _internal_mutable_plugin_data()->::tensorboard::SummaryMetadata_PluginData::MergeFrom(from._internal_plugin_data()); + } + if (from.data_class() != 0) { + _internal_set_data_class(from._internal_data_class()); + } +} + +void SummaryMetadata::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.SummaryMetadata) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void SummaryMetadata::CopyFrom(const SummaryMetadata& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.SummaryMetadata) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool SummaryMetadata::IsInitialized() const { + return true; +} + +void SummaryMetadata::InternalSwap(SummaryMetadata* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + display_name_.Swap(&other->display_name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + summary_description_.Swap(&other->summary_description_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(SummaryMetadata, data_class_) + + sizeof(SummaryMetadata::data_class_) + - PROTOBUF_FIELD_OFFSET(SummaryMetadata, plugin_data_)>( + reinterpret_cast(&plugin_data_), + reinterpret_cast(&other->plugin_data_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata SummaryMetadata::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Summary_Image::InitAsDefaultInstance() { +} +class Summary_Image::_Internal { + public: +}; + +Summary_Image::Summary_Image(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.Summary.Image) +} +Summary_Image::Summary_Image(const Summary_Image& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + encoded_image_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_encoded_image_string().empty()) { + encoded_image_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_encoded_image_string(), + GetArena()); + } + ::memcpy(&height_, &from.height_, + static_cast(reinterpret_cast(&colorspace_) - + reinterpret_cast(&height_)) + sizeof(colorspace_)); + // @@protoc_insertion_point(copy_constructor:tensorboard.Summary.Image) +} + +void Summary_Image::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Summary_Image_summary_2eproto.base); + encoded_image_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&height_, 0, static_cast( + reinterpret_cast(&colorspace_) - + reinterpret_cast(&height_)) + sizeof(colorspace_)); +} + +Summary_Image::~Summary_Image() { + // @@protoc_insertion_point(destructor:tensorboard.Summary.Image) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Summary_Image::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + encoded_image_string_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void Summary_Image::ArenaDtor(void* object) { + Summary_Image* _this = reinterpret_cast< Summary_Image* >(object); + (void)_this; +} +void Summary_Image::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Summary_Image::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Summary_Image& Summary_Image::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Summary_Image_summary_2eproto.base); + return *internal_default_instance(); +} + + +void Summary_Image::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.Summary.Image) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + encoded_image_string_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::memset(&height_, 0, static_cast( + reinterpret_cast(&colorspace_) - + reinterpret_cast(&height_)) + sizeof(colorspace_)); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Summary_Image::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // int32 height = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + height_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // int32 width = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + width_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // int32 colorspace = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + colorspace_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // bytes encoded_image_string = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + auto str = _internal_mutable_encoded_image_string(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Summary_Image::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.Summary.Image) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // int32 height = 1; + if (this->height() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(1, this->_internal_height(), target); + } + + // int32 width = 2; + if (this->width() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(2, this->_internal_width(), target); + } + + // int32 colorspace = 3; + if (this->colorspace() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(3, this->_internal_colorspace(), target); + } + + // bytes encoded_image_string = 4; + if (this->encoded_image_string().size() > 0) { + target = stream->WriteBytesMaybeAliased( + 4, this->_internal_encoded_image_string(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.Summary.Image) + return target; +} + +size_t Summary_Image::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.Summary.Image) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // bytes encoded_image_string = 4; + if (this->encoded_image_string().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_encoded_image_string()); + } + + // int32 height = 1; + if (this->height() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_height()); + } + + // int32 width = 2; + if (this->width() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_width()); + } + + // int32 colorspace = 3; + if (this->colorspace() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_colorspace()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Summary_Image::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.Summary.Image) + GOOGLE_DCHECK_NE(&from, this); + const Summary_Image* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.Summary.Image) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.Summary.Image) + MergeFrom(*source); + } +} + +void Summary_Image::MergeFrom(const Summary_Image& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.Summary.Image) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.encoded_image_string().size() > 0) { + _internal_set_encoded_image_string(from._internal_encoded_image_string()); + } + if (from.height() != 0) { + _internal_set_height(from._internal_height()); + } + if (from.width() != 0) { + _internal_set_width(from._internal_width()); + } + if (from.colorspace() != 0) { + _internal_set_colorspace(from._internal_colorspace()); + } +} + +void Summary_Image::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.Summary.Image) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Summary_Image::CopyFrom(const Summary_Image& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.Summary.Image) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Summary_Image::IsInitialized() const { + return true; +} + +void Summary_Image::InternalSwap(Summary_Image* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + encoded_image_string_.Swap(&other->encoded_image_string_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(Summary_Image, colorspace_) + + sizeof(Summary_Image::colorspace_) + - PROTOBUF_FIELD_OFFSET(Summary_Image, height_)>( + reinterpret_cast(&height_), + reinterpret_cast(&other->height_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Summary_Image::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Summary_Audio::InitAsDefaultInstance() { +} +class Summary_Audio::_Internal { + public: +}; + +Summary_Audio::Summary_Audio(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.Summary.Audio) +} +Summary_Audio::Summary_Audio(const Summary_Audio& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + encoded_audio_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_encoded_audio_string().empty()) { + encoded_audio_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_encoded_audio_string(), + GetArena()); + } + content_type_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_content_type().empty()) { + content_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_content_type(), + GetArena()); + } + ::memcpy(&num_channels_, &from.num_channels_, + static_cast(reinterpret_cast(&sample_rate_) - + reinterpret_cast(&num_channels_)) + sizeof(sample_rate_)); + // @@protoc_insertion_point(copy_constructor:tensorboard.Summary.Audio) +} + +void Summary_Audio::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Summary_Audio_summary_2eproto.base); + encoded_audio_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + content_type_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&num_channels_, 0, static_cast( + reinterpret_cast(&sample_rate_) - + reinterpret_cast(&num_channels_)) + sizeof(sample_rate_)); +} + +Summary_Audio::~Summary_Audio() { + // @@protoc_insertion_point(destructor:tensorboard.Summary.Audio) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Summary_Audio::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + encoded_audio_string_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + content_type_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void Summary_Audio::ArenaDtor(void* object) { + Summary_Audio* _this = reinterpret_cast< Summary_Audio* >(object); + (void)_this; +} +void Summary_Audio::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Summary_Audio::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Summary_Audio& Summary_Audio::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Summary_Audio_summary_2eproto.base); + return *internal_default_instance(); +} + + +void Summary_Audio::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.Summary.Audio) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + encoded_audio_string_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + content_type_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::memset(&num_channels_, 0, static_cast( + reinterpret_cast(&sample_rate_) - + reinterpret_cast(&num_channels_)) + sizeof(sample_rate_)); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Summary_Audio::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // float sample_rate = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 13)) { + sample_rate_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // int64 num_channels = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + num_channels_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // int64 length_frames = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + length_frames_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // bytes encoded_audio_string = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + auto str = _internal_mutable_encoded_audio_string(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string content_type = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + auto str = _internal_mutable_content_type(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.Summary.Audio.content_type")); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Summary_Audio::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.Summary.Audio) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // float sample_rate = 1; + if (!(this->sample_rate() <= 0 && this->sample_rate() >= 0)) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(1, this->_internal_sample_rate(), target); + } + + // int64 num_channels = 2; + if (this->num_channels() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(2, this->_internal_num_channels(), target); + } + + // int64 length_frames = 3; + if (this->length_frames() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(3, this->_internal_length_frames(), target); + } + + // bytes encoded_audio_string = 4; + if (this->encoded_audio_string().size() > 0) { + target = stream->WriteBytesMaybeAliased( + 4, this->_internal_encoded_audio_string(), target); + } + + // string content_type = 5; + if (this->content_type().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_content_type().data(), static_cast(this->_internal_content_type().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.Summary.Audio.content_type"); + target = stream->WriteStringMaybeAliased( + 5, this->_internal_content_type(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.Summary.Audio) + return target; +} + +size_t Summary_Audio::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.Summary.Audio) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // bytes encoded_audio_string = 4; + if (this->encoded_audio_string().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_encoded_audio_string()); + } + + // string content_type = 5; + if (this->content_type().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_content_type()); + } + + // int64 num_channels = 2; + if (this->num_channels() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_num_channels()); + } + + // int64 length_frames = 3; + if (this->length_frames() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_length_frames()); + } + + // float sample_rate = 1; + if (!(this->sample_rate() <= 0 && this->sample_rate() >= 0)) { + total_size += 1 + 4; + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Summary_Audio::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.Summary.Audio) + GOOGLE_DCHECK_NE(&from, this); + const Summary_Audio* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.Summary.Audio) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.Summary.Audio) + MergeFrom(*source); + } +} + +void Summary_Audio::MergeFrom(const Summary_Audio& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.Summary.Audio) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.encoded_audio_string().size() > 0) { + _internal_set_encoded_audio_string(from._internal_encoded_audio_string()); + } + if (from.content_type().size() > 0) { + _internal_set_content_type(from._internal_content_type()); + } + if (from.num_channels() != 0) { + _internal_set_num_channels(from._internal_num_channels()); + } + if (from.length_frames() != 0) { + _internal_set_length_frames(from._internal_length_frames()); + } + if (!(from.sample_rate() <= 0 && from.sample_rate() >= 0)) { + _internal_set_sample_rate(from._internal_sample_rate()); + } +} + +void Summary_Audio::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.Summary.Audio) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Summary_Audio::CopyFrom(const Summary_Audio& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.Summary.Audio) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Summary_Audio::IsInitialized() const { + return true; +} + +void Summary_Audio::InternalSwap(Summary_Audio* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + encoded_audio_string_.Swap(&other->encoded_audio_string_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + content_type_.Swap(&other->content_type_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(Summary_Audio, sample_rate_) + + sizeof(Summary_Audio::sample_rate_) + - PROTOBUF_FIELD_OFFSET(Summary_Audio, num_channels_)>( + reinterpret_cast(&num_channels_), + reinterpret_cast(&other->num_channels_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Summary_Audio::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Summary_Value::InitAsDefaultInstance() { + ::tensorboard::_Summary_Value_default_instance_._instance.get_mutable()->metadata_ = const_cast< ::tensorboard::SummaryMetadata*>( + ::tensorboard::SummaryMetadata::internal_default_instance()); + ::tensorboard::_Summary_Value_default_instance_.simple_value_ = 0; + ::tensorboard::_Summary_Value_default_instance_.obsolete_old_style_histogram_.UnsafeSetDefault( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::tensorboard::_Summary_Value_default_instance_.image_ = const_cast< ::tensorboard::Summary_Image*>( + ::tensorboard::Summary_Image::internal_default_instance()); + ::tensorboard::_Summary_Value_default_instance_.histo_ = const_cast< ::tensorboard::HistogramProto*>( + ::tensorboard::HistogramProto::internal_default_instance()); + ::tensorboard::_Summary_Value_default_instance_.audio_ = const_cast< ::tensorboard::Summary_Audio*>( + ::tensorboard::Summary_Audio::internal_default_instance()); + ::tensorboard::_Summary_Value_default_instance_.tensor_ = const_cast< ::tensorboard::TensorProto*>( + ::tensorboard::TensorProto::internal_default_instance()); +} +class Summary_Value::_Internal { + public: + static const ::tensorboard::SummaryMetadata& metadata(const Summary_Value* msg); + static const ::tensorboard::Summary_Image& image(const Summary_Value* msg); + static const ::tensorboard::HistogramProto& histo(const Summary_Value* msg); + static const ::tensorboard::Summary_Audio& audio(const Summary_Value* msg); + static const ::tensorboard::TensorProto& tensor(const Summary_Value* msg); +}; + +const ::tensorboard::SummaryMetadata& +Summary_Value::_Internal::metadata(const Summary_Value* msg) { + return *msg->metadata_; +} +const ::tensorboard::Summary_Image& +Summary_Value::_Internal::image(const Summary_Value* msg) { + return *msg->value_.image_; +} +const ::tensorboard::HistogramProto& +Summary_Value::_Internal::histo(const Summary_Value* msg) { + return *msg->value_.histo_; +} +const ::tensorboard::Summary_Audio& +Summary_Value::_Internal::audio(const Summary_Value* msg) { + return *msg->value_.audio_; +} +const ::tensorboard::TensorProto& +Summary_Value::_Internal::tensor(const Summary_Value* msg) { + return *msg->value_.tensor_; +} +void Summary_Value::set_allocated_image(::tensorboard::Summary_Image* image) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (image) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(image); + if (message_arena != submessage_arena) { + image = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, image, submessage_arena); + } + set_has_image(); + value_.image_ = image; + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Value.image) +} +void Summary_Value::set_allocated_histo(::tensorboard::HistogramProto* histo) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (histo) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(histo)->GetArena(); + if (message_arena != submessage_arena) { + histo = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, histo, submessage_arena); + } + set_has_histo(); + value_.histo_ = histo; + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Value.histo) +} +void Summary_Value::clear_histo() { + if (_internal_has_histo()) { + if (GetArena() == nullptr) { + delete value_.histo_; + } + clear_has_value(); + } +} +void Summary_Value::set_allocated_audio(::tensorboard::Summary_Audio* audio) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (audio) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(audio); + if (message_arena != submessage_arena) { + audio = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, audio, submessage_arena); + } + set_has_audio(); + value_.audio_ = audio; + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Value.audio) +} +void Summary_Value::set_allocated_tensor(::tensorboard::TensorProto* tensor) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (tensor) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(tensor)->GetArena(); + if (message_arena != submessage_arena) { + tensor = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, tensor, submessage_arena); + } + set_has_tensor(); + value_.tensor_ = tensor; + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Value.tensor) +} +void Summary_Value::clear_tensor() { + if (_internal_has_tensor()) { + if (GetArena() == nullptr) { + delete value_.tensor_; + } + clear_has_value(); + } +} +Summary_Value::Summary_Value(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.Summary.Value) +} +Summary_Value::Summary_Value(const Summary_Value& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + tag_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_tag().empty()) { + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_tag(), + GetArena()); + } + node_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_node_name().empty()) { + node_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_node_name(), + GetArena()); + } + if (from._internal_has_metadata()) { + metadata_ = new ::tensorboard::SummaryMetadata(*from.metadata_); + } else { + metadata_ = nullptr; + } + clear_has_value(); + switch (from.value_case()) { + case kSimpleValue: { + _internal_set_simple_value(from._internal_simple_value()); + break; + } + case kObsoleteOldStyleHistogram: { + _internal_set_obsolete_old_style_histogram(from._internal_obsolete_old_style_histogram()); + break; + } + case kImage: { + _internal_mutable_image()->::tensorboard::Summary_Image::MergeFrom(from._internal_image()); + break; + } + case kHisto: { + _internal_mutable_histo()->::tensorboard::HistogramProto::MergeFrom(from._internal_histo()); + break; + } + case kAudio: { + _internal_mutable_audio()->::tensorboard::Summary_Audio::MergeFrom(from._internal_audio()); + break; + } + case kTensor: { + _internal_mutable_tensor()->::tensorboard::TensorProto::MergeFrom(from._internal_tensor()); + break; + } + case VALUE_NOT_SET: { + break; + } + } + // @@protoc_insertion_point(copy_constructor:tensorboard.Summary.Value) +} + +void Summary_Value::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Summary_Value_summary_2eproto.base); + tag_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + node_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + metadata_ = nullptr; + clear_has_value(); +} + +Summary_Value::~Summary_Value() { + // @@protoc_insertion_point(destructor:tensorboard.Summary.Value) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Summary_Value::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + tag_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + node_name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete metadata_; + if (has_value()) { + clear_value(); + } +} + +void Summary_Value::ArenaDtor(void* object) { + Summary_Value* _this = reinterpret_cast< Summary_Value* >(object); + (void)_this; +} +void Summary_Value::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Summary_Value::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Summary_Value& Summary_Value::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Summary_Value_summary_2eproto.base); + return *internal_default_instance(); +} + + +void Summary_Value::clear_value() { +// @@protoc_insertion_point(one_of_clear_start:tensorboard.Summary.Value) + switch (value_case()) { + case kSimpleValue: { + // No need to clear + break; + } + case kObsoleteOldStyleHistogram: { + value_.obsolete_old_style_histogram_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + break; + } + case kImage: { + if (GetArena() == nullptr) { + delete value_.image_; + } + break; + } + case kHisto: { + if (GetArena() == nullptr) { + delete value_.histo_; + } + break; + } + case kAudio: { + if (GetArena() == nullptr) { + delete value_.audio_; + } + break; + } + case kTensor: { + if (GetArena() == nullptr) { + delete value_.tensor_; + } + break; + } + case VALUE_NOT_SET: { + break; + } + } + _oneof_case_[0] = VALUE_NOT_SET; +} + + +void Summary_Value::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.Summary.Value) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + tag_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + node_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + if (GetArena() == nullptr && metadata_ != nullptr) { + delete metadata_; + } + metadata_ = nullptr; + clear_value(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Summary_Value::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // string tag = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_tag(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.Summary.Value.tag")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // float simple_value = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 21)) { + _internal_set_simple_value(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // bytes obsolete_old_style_histogram = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_obsolete_old_style_histogram(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.Summary.Image image = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + ptr = ctx->ParseMessage(_internal_mutable_image(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.HistogramProto histo = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr = ctx->ParseMessage(_internal_mutable_histo(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.Summary.Audio audio = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr = ctx->ParseMessage(_internal_mutable_audio(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string node_name = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + auto str = _internal_mutable_node_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.Summary.Value.node_name")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.TensorProto tensor = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + ptr = ctx->ParseMessage(_internal_mutable_tensor(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .tensorboard.SummaryMetadata metadata = 9; + case 9: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 74)) { + ptr = ctx->ParseMessage(_internal_mutable_metadata(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Summary_Value::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.Summary.Value) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // string tag = 1; + if (this->tag().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_tag().data(), static_cast(this->_internal_tag().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.Summary.Value.tag"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_tag(), target); + } + + // float simple_value = 2; + if (_internal_has_simple_value()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(2, this->_internal_simple_value(), target); + } + + // bytes obsolete_old_style_histogram = 3; + if (_internal_has_obsolete_old_style_histogram()) { + target = stream->WriteBytesMaybeAliased( + 3, this->_internal_obsolete_old_style_histogram(), target); + } + + // .tensorboard.Summary.Image image = 4; + if (_internal_has_image()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 4, _Internal::image(this), target, stream); + } + + // .tensorboard.HistogramProto histo = 5; + if (_internal_has_histo()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 5, _Internal::histo(this), target, stream); + } + + // .tensorboard.Summary.Audio audio = 6; + if (_internal_has_audio()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 6, _Internal::audio(this), target, stream); + } + + // string node_name = 7; + if (this->node_name().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_node_name().data(), static_cast(this->_internal_node_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.Summary.Value.node_name"); + target = stream->WriteStringMaybeAliased( + 7, this->_internal_node_name(), target); + } + + // .tensorboard.TensorProto tensor = 8; + if (_internal_has_tensor()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 8, _Internal::tensor(this), target, stream); + } + + // .tensorboard.SummaryMetadata metadata = 9; + if (this->has_metadata()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 9, _Internal::metadata(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.Summary.Value) + return target; +} + +size_t Summary_Value::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.Summary.Value) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // string tag = 1; + if (this->tag().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_tag()); + } + + // string node_name = 7; + if (this->node_name().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_node_name()); + } + + // .tensorboard.SummaryMetadata metadata = 9; + if (this->has_metadata()) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *metadata_); + } + + switch (value_case()) { + // float simple_value = 2; + case kSimpleValue: { + total_size += 1 + 4; + break; + } + // bytes obsolete_old_style_histogram = 3; + case kObsoleteOldStyleHistogram: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_obsolete_old_style_histogram()); + break; + } + // .tensorboard.Summary.Image image = 4; + case kImage: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.image_); + break; + } + // .tensorboard.HistogramProto histo = 5; + case kHisto: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.histo_); + break; + } + // .tensorboard.Summary.Audio audio = 6; + case kAudio: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.audio_); + break; + } + // .tensorboard.TensorProto tensor = 8; + case kTensor: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.tensor_); + break; + } + case VALUE_NOT_SET: { + break; + } + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Summary_Value::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.Summary.Value) + GOOGLE_DCHECK_NE(&from, this); + const Summary_Value* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.Summary.Value) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.Summary.Value) + MergeFrom(*source); + } +} + +void Summary_Value::MergeFrom(const Summary_Value& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.Summary.Value) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.tag().size() > 0) { + _internal_set_tag(from._internal_tag()); + } + if (from.node_name().size() > 0) { + _internal_set_node_name(from._internal_node_name()); + } + if (from.has_metadata()) { + _internal_mutable_metadata()->::tensorboard::SummaryMetadata::MergeFrom(from._internal_metadata()); + } + switch (from.value_case()) { + case kSimpleValue: { + _internal_set_simple_value(from._internal_simple_value()); + break; + } + case kObsoleteOldStyleHistogram: { + _internal_set_obsolete_old_style_histogram(from._internal_obsolete_old_style_histogram()); + break; + } + case kImage: { + _internal_mutable_image()->::tensorboard::Summary_Image::MergeFrom(from._internal_image()); + break; + } + case kHisto: { + _internal_mutable_histo()->::tensorboard::HistogramProto::MergeFrom(from._internal_histo()); + break; + } + case kAudio: { + _internal_mutable_audio()->::tensorboard::Summary_Audio::MergeFrom(from._internal_audio()); + break; + } + case kTensor: { + _internal_mutable_tensor()->::tensorboard::TensorProto::MergeFrom(from._internal_tensor()); + break; + } + case VALUE_NOT_SET: { + break; + } + } +} + +void Summary_Value::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.Summary.Value) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Summary_Value::CopyFrom(const Summary_Value& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.Summary.Value) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Summary_Value::IsInitialized() const { + return true; +} + +void Summary_Value::InternalSwap(Summary_Value* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + tag_.Swap(&other->tag_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + node_name_.Swap(&other->node_name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(metadata_, other->metadata_); + swap(value_, other->value_); + swap(_oneof_case_[0], other->_oneof_case_[0]); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Summary_Value::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void Summary::InitAsDefaultInstance() { +} +class Summary::_Internal { + public: +}; + +Summary::Summary(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + value_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.Summary) +} +Summary::Summary(const Summary& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + value_(from.value_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + // @@protoc_insertion_point(copy_constructor:tensorboard.Summary) +} + +void Summary::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_Summary_summary_2eproto.base); +} + +Summary::~Summary() { + // @@protoc_insertion_point(destructor:tensorboard.Summary) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void Summary::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void Summary::ArenaDtor(void* object) { + Summary* _this = reinterpret_cast< Summary* >(object); + (void)_this; +} +void Summary::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void Summary::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const Summary& Summary::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_Summary_summary_2eproto.base); + return *internal_default_instance(); +} + + +void Summary::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.Summary) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + value_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* Summary::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .tensorboard.Summary.Value value = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_value(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* Summary::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.Summary) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .tensorboard.Summary.Value value = 1; + for (unsigned int i = 0, + n = static_cast(this->_internal_value_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(1, this->_internal_value(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.Summary) + return target; +} + +size_t Summary::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.Summary) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .tensorboard.Summary.Value value = 1; + total_size += 1UL * this->_internal_value_size(); + for (const auto& msg : this->value_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void Summary::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.Summary) + GOOGLE_DCHECK_NE(&from, this); + const Summary* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.Summary) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.Summary) + MergeFrom(*source); + } +} + +void Summary::MergeFrom(const Summary& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.Summary) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + value_.MergeFrom(from.value_); +} + +void Summary::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.Summary) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void Summary::CopyFrom(const Summary& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.Summary) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool Summary::IsInitialized() const { + return true; +} + +void Summary::InternalSwap(Summary* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + value_.InternalSwap(&other->value_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata Summary::GetMetadata() const { + return GetMetadataStatic(); +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_NOINLINE ::tensorboard::SummaryDescription* Arena::CreateMaybeMessage< ::tensorboard::SummaryDescription >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::SummaryDescription >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::SummaryMetadata_PluginData* Arena::CreateMaybeMessage< ::tensorboard::SummaryMetadata_PluginData >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::SummaryMetadata_PluginData >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::SummaryMetadata* Arena::CreateMaybeMessage< ::tensorboard::SummaryMetadata >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::SummaryMetadata >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::Summary_Image* Arena::CreateMaybeMessage< ::tensorboard::Summary_Image >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::Summary_Image >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::Summary_Audio* Arena::CreateMaybeMessage< ::tensorboard::Summary_Audio >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::Summary_Audio >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::Summary_Value* Arena::CreateMaybeMessage< ::tensorboard::Summary_Value >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::Summary_Value >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::Summary* Arena::CreateMaybeMessage< ::tensorboard::Summary >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::Summary >(arena); +} +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) +#include diff --git a/plugins/mindstudio-insight-plugins/proto/summary.pb.h b/plugins/mindstudio-insight-plugins/proto/summary.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..3526fa162dd23bafad05178b8cfc1875d39c959a --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/summary.pb.h @@ -0,0 +1,2926 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: summary.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_summary_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_summary_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +#include "histogram.pb.h" +#include "tensor.pb.h" +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_summary_2eproto +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct TableStruct_summary_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[7] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_summary_2eproto; +namespace tensorboard { +class Summary; +class SummaryDefaultTypeInternal; +extern SummaryDefaultTypeInternal _Summary_default_instance_; +class SummaryDescription; +class SummaryDescriptionDefaultTypeInternal; +extern SummaryDescriptionDefaultTypeInternal _SummaryDescription_default_instance_; +class SummaryMetadata; +class SummaryMetadataDefaultTypeInternal; +extern SummaryMetadataDefaultTypeInternal _SummaryMetadata_default_instance_; +class SummaryMetadata_PluginData; +class SummaryMetadata_PluginDataDefaultTypeInternal; +extern SummaryMetadata_PluginDataDefaultTypeInternal _SummaryMetadata_PluginData_default_instance_; +class Summary_Audio; +class Summary_AudioDefaultTypeInternal; +extern Summary_AudioDefaultTypeInternal _Summary_Audio_default_instance_; +class Summary_Image; +class Summary_ImageDefaultTypeInternal; +extern Summary_ImageDefaultTypeInternal _Summary_Image_default_instance_; +class Summary_Value; +class Summary_ValueDefaultTypeInternal; +extern Summary_ValueDefaultTypeInternal _Summary_Value_default_instance_; +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> ::tensorboard::Summary* Arena::CreateMaybeMessage<::tensorboard::Summary>(Arena*); +template<> ::tensorboard::SummaryDescription* Arena::CreateMaybeMessage<::tensorboard::SummaryDescription>(Arena*); +template<> ::tensorboard::SummaryMetadata* Arena::CreateMaybeMessage<::tensorboard::SummaryMetadata>(Arena*); +template<> ::tensorboard::SummaryMetadata_PluginData* Arena::CreateMaybeMessage<::tensorboard::SummaryMetadata_PluginData>(Arena*); +template<> ::tensorboard::Summary_Audio* Arena::CreateMaybeMessage<::tensorboard::Summary_Audio>(Arena*); +template<> ::tensorboard::Summary_Image* Arena::CreateMaybeMessage<::tensorboard::Summary_Image>(Arena*); +template<> ::tensorboard::Summary_Value* Arena::CreateMaybeMessage<::tensorboard::Summary_Value>(Arena*); +PROTOBUF_NAMESPACE_CLOSE +namespace tensorboard { + +enum DataClass : int { + DATA_CLASS_UNKNOWN = 0, + DATA_CLASS_SCALAR = 1, + DATA_CLASS_TENSOR = 2, + DATA_CLASS_BLOB_SEQUENCE = 3, + DataClass_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), + DataClass_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() +}; +bool DataClass_IsValid(int value); +constexpr DataClass DataClass_MIN = DATA_CLASS_UNKNOWN; +constexpr DataClass DataClass_MAX = DATA_CLASS_BLOB_SEQUENCE; +constexpr int DataClass_ARRAYSIZE = DataClass_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* DataClass_descriptor(); +template +inline const std::string& DataClass_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function DataClass_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + DataClass_descriptor(), enum_t_value); +} +inline bool DataClass_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, DataClass* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + DataClass_descriptor(), name, value); +} +// =================================================================== + +class SummaryDescription PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.SummaryDescription) */ { + public: + inline SummaryDescription() : SummaryDescription(nullptr) {} + virtual ~SummaryDescription(); + + SummaryDescription(const SummaryDescription& from); + SummaryDescription(SummaryDescription&& from) noexcept + : SummaryDescription() { + *this = ::std::move(from); + } + + inline SummaryDescription& operator=(const SummaryDescription& from) { + CopyFrom(from); + return *this; + } + inline SummaryDescription& operator=(SummaryDescription&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const SummaryDescription& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const SummaryDescription* internal_default_instance() { + return reinterpret_cast( + &_SummaryDescription_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(SummaryDescription& a, SummaryDescription& b) { + a.Swap(&b); + } + inline void Swap(SummaryDescription* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(SummaryDescription* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline SummaryDescription* New() const final { + return CreateMaybeMessage(nullptr); + } + + SummaryDescription* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const SummaryDescription& from); + void MergeFrom(const SummaryDescription& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(SummaryDescription* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.SummaryDescription"; + } + protected: + explicit SummaryDescription(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_summary_2eproto); + return ::descriptor_table_summary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kTypeHintFieldNumber = 1, + }; + // string type_hint = 1; + void clear_type_hint(); + const std::string& type_hint() const; + void set_type_hint(const std::string& value); + void set_type_hint(std::string&& value); + void set_type_hint(const char* value); + void set_type_hint(const char* value, size_t size); + std::string* mutable_type_hint(); + std::string* release_type_hint(); + void set_allocated_type_hint(std::string* type_hint); + private: + const std::string& _internal_type_hint() const; + void _internal_set_type_hint(const std::string& value); + std::string* _internal_mutable_type_hint(); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.SummaryDescription) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr type_hint_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_summary_2eproto; +}; +// ------------------------------------------------------------------- + +class SummaryMetadata_PluginData PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.SummaryMetadata.PluginData) */ { + public: + inline SummaryMetadata_PluginData() : SummaryMetadata_PluginData(nullptr) {} + virtual ~SummaryMetadata_PluginData(); + + SummaryMetadata_PluginData(const SummaryMetadata_PluginData& from); + SummaryMetadata_PluginData(SummaryMetadata_PluginData&& from) noexcept + : SummaryMetadata_PluginData() { + *this = ::std::move(from); + } + + inline SummaryMetadata_PluginData& operator=(const SummaryMetadata_PluginData& from) { + CopyFrom(from); + return *this; + } + inline SummaryMetadata_PluginData& operator=(SummaryMetadata_PluginData&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const SummaryMetadata_PluginData& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const SummaryMetadata_PluginData* internal_default_instance() { + return reinterpret_cast( + &_SummaryMetadata_PluginData_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(SummaryMetadata_PluginData& a, SummaryMetadata_PluginData& b) { + a.Swap(&b); + } + inline void Swap(SummaryMetadata_PluginData* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(SummaryMetadata_PluginData* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline SummaryMetadata_PluginData* New() const final { + return CreateMaybeMessage(nullptr); + } + + SummaryMetadata_PluginData* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const SummaryMetadata_PluginData& from); + void MergeFrom(const SummaryMetadata_PluginData& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(SummaryMetadata_PluginData* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.SummaryMetadata.PluginData"; + } + protected: + explicit SummaryMetadata_PluginData(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_summary_2eproto); + return ::descriptor_table_summary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kPluginNameFieldNumber = 1, + kContentFieldNumber = 2, + }; + // string plugin_name = 1; + void clear_plugin_name(); + const std::string& plugin_name() const; + void set_plugin_name(const std::string& value); + void set_plugin_name(std::string&& value); + void set_plugin_name(const char* value); + void set_plugin_name(const char* value, size_t size); + std::string* mutable_plugin_name(); + std::string* release_plugin_name(); + void set_allocated_plugin_name(std::string* plugin_name); + private: + const std::string& _internal_plugin_name() const; + void _internal_set_plugin_name(const std::string& value); + std::string* _internal_mutable_plugin_name(); + public: + + // bytes content = 2; + void clear_content(); + const std::string& content() const; + void set_content(const std::string& value); + void set_content(std::string&& value); + void set_content(const char* value); + void set_content(const void* value, size_t size); + std::string* mutable_content(); + std::string* release_content(); + void set_allocated_content(std::string* content); + private: + const std::string& _internal_content() const; + void _internal_set_content(const std::string& value); + std::string* _internal_mutable_content(); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.SummaryMetadata.PluginData) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr plugin_name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr content_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_summary_2eproto; +}; +// ------------------------------------------------------------------- + +class SummaryMetadata PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.SummaryMetadata) */ { + public: + inline SummaryMetadata() : SummaryMetadata(nullptr) {} + virtual ~SummaryMetadata(); + + SummaryMetadata(const SummaryMetadata& from); + SummaryMetadata(SummaryMetadata&& from) noexcept + : SummaryMetadata() { + *this = ::std::move(from); + } + + inline SummaryMetadata& operator=(const SummaryMetadata& from) { + CopyFrom(from); + return *this; + } + inline SummaryMetadata& operator=(SummaryMetadata&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const SummaryMetadata& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const SummaryMetadata* internal_default_instance() { + return reinterpret_cast( + &_SummaryMetadata_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(SummaryMetadata& a, SummaryMetadata& b) { + a.Swap(&b); + } + inline void Swap(SummaryMetadata* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(SummaryMetadata* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline SummaryMetadata* New() const final { + return CreateMaybeMessage(nullptr); + } + + SummaryMetadata* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const SummaryMetadata& from); + void MergeFrom(const SummaryMetadata& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(SummaryMetadata* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.SummaryMetadata"; + } + protected: + explicit SummaryMetadata(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_summary_2eproto); + return ::descriptor_table_summary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef SummaryMetadata_PluginData PluginData; + + // accessors ------------------------------------------------------- + + enum : int { + kDisplayNameFieldNumber = 2, + kSummaryDescriptionFieldNumber = 3, + kPluginDataFieldNumber = 1, + kDataClassFieldNumber = 4, + }; + // string display_name = 2; + void clear_display_name(); + const std::string& display_name() const; + void set_display_name(const std::string& value); + void set_display_name(std::string&& value); + void set_display_name(const char* value); + void set_display_name(const char* value, size_t size); + std::string* mutable_display_name(); + std::string* release_display_name(); + void set_allocated_display_name(std::string* display_name); + private: + const std::string& _internal_display_name() const; + void _internal_set_display_name(const std::string& value); + std::string* _internal_mutable_display_name(); + public: + + // string summary_description = 3; + void clear_summary_description(); + const std::string& summary_description() const; + void set_summary_description(const std::string& value); + void set_summary_description(std::string&& value); + void set_summary_description(const char* value); + void set_summary_description(const char* value, size_t size); + std::string* mutable_summary_description(); + std::string* release_summary_description(); + void set_allocated_summary_description(std::string* summary_description); + private: + const std::string& _internal_summary_description() const; + void _internal_set_summary_description(const std::string& value); + std::string* _internal_mutable_summary_description(); + public: + + // .tensorboard.SummaryMetadata.PluginData plugin_data = 1; + bool has_plugin_data() const; + private: + bool _internal_has_plugin_data() const; + public: + void clear_plugin_data(); + const ::tensorboard::SummaryMetadata_PluginData& plugin_data() const; + ::tensorboard::SummaryMetadata_PluginData* release_plugin_data(); + ::tensorboard::SummaryMetadata_PluginData* mutable_plugin_data(); + void set_allocated_plugin_data(::tensorboard::SummaryMetadata_PluginData* plugin_data); + private: + const ::tensorboard::SummaryMetadata_PluginData& _internal_plugin_data() const; + ::tensorboard::SummaryMetadata_PluginData* _internal_mutable_plugin_data(); + public: + void unsafe_arena_set_allocated_plugin_data( + ::tensorboard::SummaryMetadata_PluginData* plugin_data); + ::tensorboard::SummaryMetadata_PluginData* unsafe_arena_release_plugin_data(); + + // .tensorboard.DataClass data_class = 4; + void clear_data_class(); + ::tensorboard::DataClass data_class() const; + void set_data_class(::tensorboard::DataClass value); + private: + ::tensorboard::DataClass _internal_data_class() const; + void _internal_set_data_class(::tensorboard::DataClass value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.SummaryMetadata) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr display_name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr summary_description_; + ::tensorboard::SummaryMetadata_PluginData* plugin_data_; + int data_class_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_summary_2eproto; +}; +// ------------------------------------------------------------------- + +class Summary_Image PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.Summary.Image) */ { + public: + inline Summary_Image() : Summary_Image(nullptr) {} + virtual ~Summary_Image(); + + Summary_Image(const Summary_Image& from); + Summary_Image(Summary_Image&& from) noexcept + : Summary_Image() { + *this = ::std::move(from); + } + + inline Summary_Image& operator=(const Summary_Image& from) { + CopyFrom(from); + return *this; + } + inline Summary_Image& operator=(Summary_Image&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Summary_Image& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Summary_Image* internal_default_instance() { + return reinterpret_cast( + &_Summary_Image_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(Summary_Image& a, Summary_Image& b) { + a.Swap(&b); + } + inline void Swap(Summary_Image* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Summary_Image* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Summary_Image* New() const final { + return CreateMaybeMessage(nullptr); + } + + Summary_Image* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Summary_Image& from); + void MergeFrom(const Summary_Image& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Summary_Image* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.Summary.Image"; + } + protected: + explicit Summary_Image(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_summary_2eproto); + return ::descriptor_table_summary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kEncodedImageStringFieldNumber = 4, + kHeightFieldNumber = 1, + kWidthFieldNumber = 2, + kColorspaceFieldNumber = 3, + }; + // bytes encoded_image_string = 4; + void clear_encoded_image_string(); + const std::string& encoded_image_string() const; + void set_encoded_image_string(const std::string& value); + void set_encoded_image_string(std::string&& value); + void set_encoded_image_string(const char* value); + void set_encoded_image_string(const void* value, size_t size); + std::string* mutable_encoded_image_string(); + std::string* release_encoded_image_string(); + void set_allocated_encoded_image_string(std::string* encoded_image_string); + private: + const std::string& _internal_encoded_image_string() const; + void _internal_set_encoded_image_string(const std::string& value); + std::string* _internal_mutable_encoded_image_string(); + public: + + // int32 height = 1; + void clear_height(); + ::PROTOBUF_NAMESPACE_ID::int32 height() const; + void set_height(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_height() const; + void _internal_set_height(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // int32 width = 2; + void clear_width(); + ::PROTOBUF_NAMESPACE_ID::int32 width() const; + void set_width(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_width() const; + void _internal_set_width(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // int32 colorspace = 3; + void clear_colorspace(); + ::PROTOBUF_NAMESPACE_ID::int32 colorspace() const; + void set_colorspace(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_colorspace() const; + void _internal_set_colorspace(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.Summary.Image) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr encoded_image_string_; + ::PROTOBUF_NAMESPACE_ID::int32 height_; + ::PROTOBUF_NAMESPACE_ID::int32 width_; + ::PROTOBUF_NAMESPACE_ID::int32 colorspace_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_summary_2eproto; +}; +// ------------------------------------------------------------------- + +class Summary_Audio PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.Summary.Audio) */ { + public: + inline Summary_Audio() : Summary_Audio(nullptr) {} + virtual ~Summary_Audio(); + + Summary_Audio(const Summary_Audio& from); + Summary_Audio(Summary_Audio&& from) noexcept + : Summary_Audio() { + *this = ::std::move(from); + } + + inline Summary_Audio& operator=(const Summary_Audio& from) { + CopyFrom(from); + return *this; + } + inline Summary_Audio& operator=(Summary_Audio&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Summary_Audio& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Summary_Audio* internal_default_instance() { + return reinterpret_cast( + &_Summary_Audio_default_instance_); + } + static constexpr int kIndexInFileMessages = + 4; + + friend void swap(Summary_Audio& a, Summary_Audio& b) { + a.Swap(&b); + } + inline void Swap(Summary_Audio* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Summary_Audio* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Summary_Audio* New() const final { + return CreateMaybeMessage(nullptr); + } + + Summary_Audio* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Summary_Audio& from); + void MergeFrom(const Summary_Audio& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Summary_Audio* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.Summary.Audio"; + } + protected: + explicit Summary_Audio(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_summary_2eproto); + return ::descriptor_table_summary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kEncodedAudioStringFieldNumber = 4, + kContentTypeFieldNumber = 5, + kNumChannelsFieldNumber = 2, + kLengthFramesFieldNumber = 3, + kSampleRateFieldNumber = 1, + }; + // bytes encoded_audio_string = 4; + void clear_encoded_audio_string(); + const std::string& encoded_audio_string() const; + void set_encoded_audio_string(const std::string& value); + void set_encoded_audio_string(std::string&& value); + void set_encoded_audio_string(const char* value); + void set_encoded_audio_string(const void* value, size_t size); + std::string* mutable_encoded_audio_string(); + std::string* release_encoded_audio_string(); + void set_allocated_encoded_audio_string(std::string* encoded_audio_string); + private: + const std::string& _internal_encoded_audio_string() const; + void _internal_set_encoded_audio_string(const std::string& value); + std::string* _internal_mutable_encoded_audio_string(); + public: + + // string content_type = 5; + void clear_content_type(); + const std::string& content_type() const; + void set_content_type(const std::string& value); + void set_content_type(std::string&& value); + void set_content_type(const char* value); + void set_content_type(const char* value, size_t size); + std::string* mutable_content_type(); + std::string* release_content_type(); + void set_allocated_content_type(std::string* content_type); + private: + const std::string& _internal_content_type() const; + void _internal_set_content_type(const std::string& value); + std::string* _internal_mutable_content_type(); + public: + + // int64 num_channels = 2; + void clear_num_channels(); + ::PROTOBUF_NAMESPACE_ID::int64 num_channels() const; + void set_num_channels(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_num_channels() const; + void _internal_set_num_channels(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // int64 length_frames = 3; + void clear_length_frames(); + ::PROTOBUF_NAMESPACE_ID::int64 length_frames() const; + void set_length_frames(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_length_frames() const; + void _internal_set_length_frames(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // float sample_rate = 1; + void clear_sample_rate(); + float sample_rate() const; + void set_sample_rate(float value); + private: + float _internal_sample_rate() const; + void _internal_set_sample_rate(float value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.Summary.Audio) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr encoded_audio_string_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr content_type_; + ::PROTOBUF_NAMESPACE_ID::int64 num_channels_; + ::PROTOBUF_NAMESPACE_ID::int64 length_frames_; + float sample_rate_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_summary_2eproto; +}; +// ------------------------------------------------------------------- + +class Summary_Value PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.Summary.Value) */ { + public: + inline Summary_Value() : Summary_Value(nullptr) {} + virtual ~Summary_Value(); + + Summary_Value(const Summary_Value& from); + Summary_Value(Summary_Value&& from) noexcept + : Summary_Value() { + *this = ::std::move(from); + } + + inline Summary_Value& operator=(const Summary_Value& from) { + CopyFrom(from); + return *this; + } + inline Summary_Value& operator=(Summary_Value&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Summary_Value& default_instance(); + + enum ValueCase { + kSimpleValue = 2, + kObsoleteOldStyleHistogram = 3, + kImage = 4, + kHisto = 5, + kAudio = 6, + kTensor = 8, + VALUE_NOT_SET = 0, + }; + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Summary_Value* internal_default_instance() { + return reinterpret_cast( + &_Summary_Value_default_instance_); + } + static constexpr int kIndexInFileMessages = + 5; + + friend void swap(Summary_Value& a, Summary_Value& b) { + a.Swap(&b); + } + inline void Swap(Summary_Value* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Summary_Value* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Summary_Value* New() const final { + return CreateMaybeMessage(nullptr); + } + + Summary_Value* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Summary_Value& from); + void MergeFrom(const Summary_Value& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Summary_Value* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.Summary.Value"; + } + protected: + explicit Summary_Value(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_summary_2eproto); + return ::descriptor_table_summary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kTagFieldNumber = 1, + kNodeNameFieldNumber = 7, + kMetadataFieldNumber = 9, + kSimpleValueFieldNumber = 2, + kObsoleteOldStyleHistogramFieldNumber = 3, + kImageFieldNumber = 4, + kHistoFieldNumber = 5, + kAudioFieldNumber = 6, + kTensorFieldNumber = 8, + }; + // string tag = 1; + void clear_tag(); + const std::string& tag() const; + void set_tag(const std::string& value); + void set_tag(std::string&& value); + void set_tag(const char* value); + void set_tag(const char* value, size_t size); + std::string* mutable_tag(); + std::string* release_tag(); + void set_allocated_tag(std::string* tag); + private: + const std::string& _internal_tag() const; + void _internal_set_tag(const std::string& value); + std::string* _internal_mutable_tag(); + public: + + // string node_name = 7; + void clear_node_name(); + const std::string& node_name() const; + void set_node_name(const std::string& value); + void set_node_name(std::string&& value); + void set_node_name(const char* value); + void set_node_name(const char* value, size_t size); + std::string* mutable_node_name(); + std::string* release_node_name(); + void set_allocated_node_name(std::string* node_name); + private: + const std::string& _internal_node_name() const; + void _internal_set_node_name(const std::string& value); + std::string* _internal_mutable_node_name(); + public: + + // .tensorboard.SummaryMetadata metadata = 9; + bool has_metadata() const; + private: + bool _internal_has_metadata() const; + public: + void clear_metadata(); + const ::tensorboard::SummaryMetadata& metadata() const; + ::tensorboard::SummaryMetadata* release_metadata(); + ::tensorboard::SummaryMetadata* mutable_metadata(); + void set_allocated_metadata(::tensorboard::SummaryMetadata* metadata); + private: + const ::tensorboard::SummaryMetadata& _internal_metadata() const; + ::tensorboard::SummaryMetadata* _internal_mutable_metadata(); + public: + void unsafe_arena_set_allocated_metadata( + ::tensorboard::SummaryMetadata* metadata); + ::tensorboard::SummaryMetadata* unsafe_arena_release_metadata(); + + // float simple_value = 2; + private: + bool _internal_has_simple_value() const; + public: + void clear_simple_value(); + float simple_value() const; + void set_simple_value(float value); + private: + float _internal_simple_value() const; + void _internal_set_simple_value(float value); + public: + + // bytes obsolete_old_style_histogram = 3; + private: + bool _internal_has_obsolete_old_style_histogram() const; + public: + void clear_obsolete_old_style_histogram(); + const std::string& obsolete_old_style_histogram() const; + void set_obsolete_old_style_histogram(const std::string& value); + void set_obsolete_old_style_histogram(std::string&& value); + void set_obsolete_old_style_histogram(const char* value); + void set_obsolete_old_style_histogram(const void* value, size_t size); + std::string* mutable_obsolete_old_style_histogram(); + std::string* release_obsolete_old_style_histogram(); + void set_allocated_obsolete_old_style_histogram(std::string* obsolete_old_style_histogram); + private: + const std::string& _internal_obsolete_old_style_histogram() const; + void _internal_set_obsolete_old_style_histogram(const std::string& value); + std::string* _internal_mutable_obsolete_old_style_histogram(); + public: + + // .tensorboard.Summary.Image image = 4; + bool has_image() const; + private: + bool _internal_has_image() const; + public: + void clear_image(); + const ::tensorboard::Summary_Image& image() const; + ::tensorboard::Summary_Image* release_image(); + ::tensorboard::Summary_Image* mutable_image(); + void set_allocated_image(::tensorboard::Summary_Image* image); + private: + const ::tensorboard::Summary_Image& _internal_image() const; + ::tensorboard::Summary_Image* _internal_mutable_image(); + public: + void unsafe_arena_set_allocated_image( + ::tensorboard::Summary_Image* image); + ::tensorboard::Summary_Image* unsafe_arena_release_image(); + + // .tensorboard.HistogramProto histo = 5; + bool has_histo() const; + private: + bool _internal_has_histo() const; + public: + void clear_histo(); + const ::tensorboard::HistogramProto& histo() const; + ::tensorboard::HistogramProto* release_histo(); + ::tensorboard::HistogramProto* mutable_histo(); + void set_allocated_histo(::tensorboard::HistogramProto* histo); + private: + const ::tensorboard::HistogramProto& _internal_histo() const; + ::tensorboard::HistogramProto* _internal_mutable_histo(); + public: + void unsafe_arena_set_allocated_histo( + ::tensorboard::HistogramProto* histo); + ::tensorboard::HistogramProto* unsafe_arena_release_histo(); + + // .tensorboard.Summary.Audio audio = 6; + bool has_audio() const; + private: + bool _internal_has_audio() const; + public: + void clear_audio(); + const ::tensorboard::Summary_Audio& audio() const; + ::tensorboard::Summary_Audio* release_audio(); + ::tensorboard::Summary_Audio* mutable_audio(); + void set_allocated_audio(::tensorboard::Summary_Audio* audio); + private: + const ::tensorboard::Summary_Audio& _internal_audio() const; + ::tensorboard::Summary_Audio* _internal_mutable_audio(); + public: + void unsafe_arena_set_allocated_audio( + ::tensorboard::Summary_Audio* audio); + ::tensorboard::Summary_Audio* unsafe_arena_release_audio(); + + // .tensorboard.TensorProto tensor = 8; + bool has_tensor() const; + private: + bool _internal_has_tensor() const; + public: + void clear_tensor(); + const ::tensorboard::TensorProto& tensor() const; + ::tensorboard::TensorProto* release_tensor(); + ::tensorboard::TensorProto* mutable_tensor(); + void set_allocated_tensor(::tensorboard::TensorProto* tensor); + private: + const ::tensorboard::TensorProto& _internal_tensor() const; + ::tensorboard::TensorProto* _internal_mutable_tensor(); + public: + void unsafe_arena_set_allocated_tensor( + ::tensorboard::TensorProto* tensor); + ::tensorboard::TensorProto* unsafe_arena_release_tensor(); + + void clear_value(); + ValueCase value_case() const; + // @@protoc_insertion_point(class_scope:tensorboard.Summary.Value) + private: + class _Internal; + void set_has_simple_value(); + void set_has_obsolete_old_style_histogram(); + void set_has_image(); + void set_has_histo(); + void set_has_audio(); + void set_has_tensor(); + + inline bool has_value() const; + inline void clear_has_value(); + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr tag_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr node_name_; + ::tensorboard::SummaryMetadata* metadata_; + union ValueUnion { + ValueUnion() {} + float simple_value_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr obsolete_old_style_histogram_; + ::tensorboard::Summary_Image* image_; + ::tensorboard::HistogramProto* histo_; + ::tensorboard::Summary_Audio* audio_; + ::tensorboard::TensorProto* tensor_; + } value_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::uint32 _oneof_case_[1]; + + friend struct ::TableStruct_summary_2eproto; +}; +// ------------------------------------------------------------------- + +class Summary PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.Summary) */ { + public: + inline Summary() : Summary(nullptr) {} + virtual ~Summary(); + + Summary(const Summary& from); + Summary(Summary&& from) noexcept + : Summary() { + *this = ::std::move(from); + } + + inline Summary& operator=(const Summary& from) { + CopyFrom(from); + return *this; + } + inline Summary& operator=(Summary&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const Summary& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const Summary* internal_default_instance() { + return reinterpret_cast( + &_Summary_default_instance_); + } + static constexpr int kIndexInFileMessages = + 6; + + friend void swap(Summary& a, Summary& b) { + a.Swap(&b); + } + inline void Swap(Summary* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(Summary* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline Summary* New() const final { + return CreateMaybeMessage(nullptr); + } + + Summary* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const Summary& from); + void MergeFrom(const Summary& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(Summary* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.Summary"; + } + protected: + explicit Summary(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_summary_2eproto); + return ::descriptor_table_summary_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef Summary_Image Image; + typedef Summary_Audio Audio; + typedef Summary_Value Value; + + // accessors ------------------------------------------------------- + + enum : int { + kValueFieldNumber = 1, + }; + // repeated .tensorboard.Summary.Value value = 1; + int value_size() const; + private: + int _internal_value_size() const; + public: + void clear_value(); + ::tensorboard::Summary_Value* mutable_value(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::Summary_Value >* + mutable_value(); + private: + const ::tensorboard::Summary_Value& _internal_value(int index) const; + ::tensorboard::Summary_Value* _internal_add_value(); + public: + const ::tensorboard::Summary_Value& value(int index) const; + ::tensorboard::Summary_Value* add_value(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::Summary_Value >& + value() const; + + // @@protoc_insertion_point(class_scope:tensorboard.Summary) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::Summary_Value > value_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_summary_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// SummaryDescription + +// string type_hint = 1; +inline void SummaryDescription::clear_type_hint() { + type_hint_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& SummaryDescription::type_hint() const { + // @@protoc_insertion_point(field_get:tensorboard.SummaryDescription.type_hint) + return _internal_type_hint(); +} +inline void SummaryDescription::set_type_hint(const std::string& value) { + _internal_set_type_hint(value); + // @@protoc_insertion_point(field_set:tensorboard.SummaryDescription.type_hint) +} +inline std::string* SummaryDescription::mutable_type_hint() { + // @@protoc_insertion_point(field_mutable:tensorboard.SummaryDescription.type_hint) + return _internal_mutable_type_hint(); +} +inline const std::string& SummaryDescription::_internal_type_hint() const { + return type_hint_.Get(); +} +inline void SummaryDescription::_internal_set_type_hint(const std::string& value) { + + type_hint_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SummaryDescription::set_type_hint(std::string&& value) { + + type_hint_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.SummaryDescription.type_hint) +} +inline void SummaryDescription::set_type_hint(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + type_hint_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.SummaryDescription.type_hint) +} +inline void SummaryDescription::set_type_hint(const char* value, + size_t size) { + + type_hint_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.SummaryDescription.type_hint) +} +inline std::string* SummaryDescription::_internal_mutable_type_hint() { + + return type_hint_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SummaryDescription::release_type_hint() { + // @@protoc_insertion_point(field_release:tensorboard.SummaryDescription.type_hint) + return type_hint_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SummaryDescription::set_allocated_type_hint(std::string* type_hint) { + if (type_hint != nullptr) { + + } else { + + } + type_hint_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), type_hint, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.SummaryDescription.type_hint) +} + +// ------------------------------------------------------------------- + +// SummaryMetadata_PluginData + +// string plugin_name = 1; +inline void SummaryMetadata_PluginData::clear_plugin_name() { + plugin_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& SummaryMetadata_PluginData::plugin_name() const { + // @@protoc_insertion_point(field_get:tensorboard.SummaryMetadata.PluginData.plugin_name) + return _internal_plugin_name(); +} +inline void SummaryMetadata_PluginData::set_plugin_name(const std::string& value) { + _internal_set_plugin_name(value); + // @@protoc_insertion_point(field_set:tensorboard.SummaryMetadata.PluginData.plugin_name) +} +inline std::string* SummaryMetadata_PluginData::mutable_plugin_name() { + // @@protoc_insertion_point(field_mutable:tensorboard.SummaryMetadata.PluginData.plugin_name) + return _internal_mutable_plugin_name(); +} +inline const std::string& SummaryMetadata_PluginData::_internal_plugin_name() const { + return plugin_name_.Get(); +} +inline void SummaryMetadata_PluginData::_internal_set_plugin_name(const std::string& value) { + + plugin_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SummaryMetadata_PluginData::set_plugin_name(std::string&& value) { + + plugin_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.SummaryMetadata.PluginData.plugin_name) +} +inline void SummaryMetadata_PluginData::set_plugin_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + plugin_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.SummaryMetadata.PluginData.plugin_name) +} +inline void SummaryMetadata_PluginData::set_plugin_name(const char* value, + size_t size) { + + plugin_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.SummaryMetadata.PluginData.plugin_name) +} +inline std::string* SummaryMetadata_PluginData::_internal_mutable_plugin_name() { + + return plugin_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SummaryMetadata_PluginData::release_plugin_name() { + // @@protoc_insertion_point(field_release:tensorboard.SummaryMetadata.PluginData.plugin_name) + return plugin_name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SummaryMetadata_PluginData::set_allocated_plugin_name(std::string* plugin_name) { + if (plugin_name != nullptr) { + + } else { + + } + plugin_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), plugin_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.SummaryMetadata.PluginData.plugin_name) +} + +// bytes content = 2; +inline void SummaryMetadata_PluginData::clear_content() { + content_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& SummaryMetadata_PluginData::content() const { + // @@protoc_insertion_point(field_get:tensorboard.SummaryMetadata.PluginData.content) + return _internal_content(); +} +inline void SummaryMetadata_PluginData::set_content(const std::string& value) { + _internal_set_content(value); + // @@protoc_insertion_point(field_set:tensorboard.SummaryMetadata.PluginData.content) +} +inline std::string* SummaryMetadata_PluginData::mutable_content() { + // @@protoc_insertion_point(field_mutable:tensorboard.SummaryMetadata.PluginData.content) + return _internal_mutable_content(); +} +inline const std::string& SummaryMetadata_PluginData::_internal_content() const { + return content_.Get(); +} +inline void SummaryMetadata_PluginData::_internal_set_content(const std::string& value) { + + content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SummaryMetadata_PluginData::set_content(std::string&& value) { + + content_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.SummaryMetadata.PluginData.content) +} +inline void SummaryMetadata_PluginData::set_content(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.SummaryMetadata.PluginData.content) +} +inline void SummaryMetadata_PluginData::set_content(const void* value, + size_t size) { + + content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.SummaryMetadata.PluginData.content) +} +inline std::string* SummaryMetadata_PluginData::_internal_mutable_content() { + + return content_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SummaryMetadata_PluginData::release_content() { + // @@protoc_insertion_point(field_release:tensorboard.SummaryMetadata.PluginData.content) + return content_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SummaryMetadata_PluginData::set_allocated_content(std::string* content) { + if (content != nullptr) { + + } else { + + } + content_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), content, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.SummaryMetadata.PluginData.content) +} + +// ------------------------------------------------------------------- + +// SummaryMetadata + +// .tensorboard.SummaryMetadata.PluginData plugin_data = 1; +inline bool SummaryMetadata::_internal_has_plugin_data() const { + return this != internal_default_instance() && plugin_data_ != nullptr; +} +inline bool SummaryMetadata::has_plugin_data() const { + return _internal_has_plugin_data(); +} +inline void SummaryMetadata::clear_plugin_data() { + if (GetArena() == nullptr && plugin_data_ != nullptr) { + delete plugin_data_; + } + plugin_data_ = nullptr; +} +inline const ::tensorboard::SummaryMetadata_PluginData& SummaryMetadata::_internal_plugin_data() const { + const ::tensorboard::SummaryMetadata_PluginData* p = plugin_data_; + return p != nullptr ? *p : *reinterpret_cast( + &::tensorboard::_SummaryMetadata_PluginData_default_instance_); +} +inline const ::tensorboard::SummaryMetadata_PluginData& SummaryMetadata::plugin_data() const { + // @@protoc_insertion_point(field_get:tensorboard.SummaryMetadata.plugin_data) + return _internal_plugin_data(); +} +inline void SummaryMetadata::unsafe_arena_set_allocated_plugin_data( + ::tensorboard::SummaryMetadata_PluginData* plugin_data) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(plugin_data_); + } + plugin_data_ = plugin_data; + if (plugin_data) { + + } else { + + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.SummaryMetadata.plugin_data) +} +inline ::tensorboard::SummaryMetadata_PluginData* SummaryMetadata::release_plugin_data() { + + ::tensorboard::SummaryMetadata_PluginData* temp = plugin_data_; + plugin_data_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::tensorboard::SummaryMetadata_PluginData* SummaryMetadata::unsafe_arena_release_plugin_data() { + // @@protoc_insertion_point(field_release:tensorboard.SummaryMetadata.plugin_data) + + ::tensorboard::SummaryMetadata_PluginData* temp = plugin_data_; + plugin_data_ = nullptr; + return temp; +} +inline ::tensorboard::SummaryMetadata_PluginData* SummaryMetadata::_internal_mutable_plugin_data() { + + if (plugin_data_ == nullptr) { + auto* p = CreateMaybeMessage<::tensorboard::SummaryMetadata_PluginData>(GetArena()); + plugin_data_ = p; + } + return plugin_data_; +} +inline ::tensorboard::SummaryMetadata_PluginData* SummaryMetadata::mutable_plugin_data() { + // @@protoc_insertion_point(field_mutable:tensorboard.SummaryMetadata.plugin_data) + return _internal_mutable_plugin_data(); +} +inline void SummaryMetadata::set_allocated_plugin_data(::tensorboard::SummaryMetadata_PluginData* plugin_data) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete plugin_data_; + } + if (plugin_data) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(plugin_data); + if (message_arena != submessage_arena) { + plugin_data = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, plugin_data, submessage_arena); + } + + } else { + + } + plugin_data_ = plugin_data; + // @@protoc_insertion_point(field_set_allocated:tensorboard.SummaryMetadata.plugin_data) +} + +// string display_name = 2; +inline void SummaryMetadata::clear_display_name() { + display_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& SummaryMetadata::display_name() const { + // @@protoc_insertion_point(field_get:tensorboard.SummaryMetadata.display_name) + return _internal_display_name(); +} +inline void SummaryMetadata::set_display_name(const std::string& value) { + _internal_set_display_name(value); + // @@protoc_insertion_point(field_set:tensorboard.SummaryMetadata.display_name) +} +inline std::string* SummaryMetadata::mutable_display_name() { + // @@protoc_insertion_point(field_mutable:tensorboard.SummaryMetadata.display_name) + return _internal_mutable_display_name(); +} +inline const std::string& SummaryMetadata::_internal_display_name() const { + return display_name_.Get(); +} +inline void SummaryMetadata::_internal_set_display_name(const std::string& value) { + + display_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SummaryMetadata::set_display_name(std::string&& value) { + + display_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.SummaryMetadata.display_name) +} +inline void SummaryMetadata::set_display_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + display_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.SummaryMetadata.display_name) +} +inline void SummaryMetadata::set_display_name(const char* value, + size_t size) { + + display_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.SummaryMetadata.display_name) +} +inline std::string* SummaryMetadata::_internal_mutable_display_name() { + + return display_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SummaryMetadata::release_display_name() { + // @@protoc_insertion_point(field_release:tensorboard.SummaryMetadata.display_name) + return display_name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SummaryMetadata::set_allocated_display_name(std::string* display_name) { + if (display_name != nullptr) { + + } else { + + } + display_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), display_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.SummaryMetadata.display_name) +} + +// string summary_description = 3; +inline void SummaryMetadata::clear_summary_description() { + summary_description_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& SummaryMetadata::summary_description() const { + // @@protoc_insertion_point(field_get:tensorboard.SummaryMetadata.summary_description) + return _internal_summary_description(); +} +inline void SummaryMetadata::set_summary_description(const std::string& value) { + _internal_set_summary_description(value); + // @@protoc_insertion_point(field_set:tensorboard.SummaryMetadata.summary_description) +} +inline std::string* SummaryMetadata::mutable_summary_description() { + // @@protoc_insertion_point(field_mutable:tensorboard.SummaryMetadata.summary_description) + return _internal_mutable_summary_description(); +} +inline const std::string& SummaryMetadata::_internal_summary_description() const { + return summary_description_.Get(); +} +inline void SummaryMetadata::_internal_set_summary_description(const std::string& value) { + + summary_description_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void SummaryMetadata::set_summary_description(std::string&& value) { + + summary_description_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.SummaryMetadata.summary_description) +} +inline void SummaryMetadata::set_summary_description(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + summary_description_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.SummaryMetadata.summary_description) +} +inline void SummaryMetadata::set_summary_description(const char* value, + size_t size) { + + summary_description_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.SummaryMetadata.summary_description) +} +inline std::string* SummaryMetadata::_internal_mutable_summary_description() { + + return summary_description_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* SummaryMetadata::release_summary_description() { + // @@protoc_insertion_point(field_release:tensorboard.SummaryMetadata.summary_description) + return summary_description_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void SummaryMetadata::set_allocated_summary_description(std::string* summary_description) { + if (summary_description != nullptr) { + + } else { + + } + summary_description_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), summary_description, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.SummaryMetadata.summary_description) +} + +// .tensorboard.DataClass data_class = 4; +inline void SummaryMetadata::clear_data_class() { + data_class_ = 0; +} +inline ::tensorboard::DataClass SummaryMetadata::_internal_data_class() const { + return static_cast< ::tensorboard::DataClass >(data_class_); +} +inline ::tensorboard::DataClass SummaryMetadata::data_class() const { + // @@protoc_insertion_point(field_get:tensorboard.SummaryMetadata.data_class) + return _internal_data_class(); +} +inline void SummaryMetadata::_internal_set_data_class(::tensorboard::DataClass value) { + + data_class_ = value; +} +inline void SummaryMetadata::set_data_class(::tensorboard::DataClass value) { + _internal_set_data_class(value); + // @@protoc_insertion_point(field_set:tensorboard.SummaryMetadata.data_class) +} + +// ------------------------------------------------------------------- + +// Summary_Image + +// int32 height = 1; +inline void Summary_Image::clear_height() { + height_ = 0; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::_internal_height() const { + return height_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::height() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Image.height) + return _internal_height(); +} +inline void Summary_Image::_internal_set_height(::PROTOBUF_NAMESPACE_ID::int32 value) { + + height_ = value; +} +inline void Summary_Image::set_height(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_height(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Image.height) +} + +// int32 width = 2; +inline void Summary_Image::clear_width() { + width_ = 0; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::_internal_width() const { + return width_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::width() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Image.width) + return _internal_width(); +} +inline void Summary_Image::_internal_set_width(::PROTOBUF_NAMESPACE_ID::int32 value) { + + width_ = value; +} +inline void Summary_Image::set_width(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_width(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Image.width) +} + +// int32 colorspace = 3; +inline void Summary_Image::clear_colorspace() { + colorspace_ = 0; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::_internal_colorspace() const { + return colorspace_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 Summary_Image::colorspace() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Image.colorspace) + return _internal_colorspace(); +} +inline void Summary_Image::_internal_set_colorspace(::PROTOBUF_NAMESPACE_ID::int32 value) { + + colorspace_ = value; +} +inline void Summary_Image::set_colorspace(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_colorspace(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Image.colorspace) +} + +// bytes encoded_image_string = 4; +inline void Summary_Image::clear_encoded_image_string() { + encoded_image_string_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Summary_Image::encoded_image_string() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Image.encoded_image_string) + return _internal_encoded_image_string(); +} +inline void Summary_Image::set_encoded_image_string(const std::string& value) { + _internal_set_encoded_image_string(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Image.encoded_image_string) +} +inline std::string* Summary_Image::mutable_encoded_image_string() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Image.encoded_image_string) + return _internal_mutable_encoded_image_string(); +} +inline const std::string& Summary_Image::_internal_encoded_image_string() const { + return encoded_image_string_.Get(); +} +inline void Summary_Image::_internal_set_encoded_image_string(const std::string& value) { + + encoded_image_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Summary_Image::set_encoded_image_string(std::string&& value) { + + encoded_image_string_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.Summary.Image.encoded_image_string) +} +inline void Summary_Image::set_encoded_image_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + encoded_image_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.Summary.Image.encoded_image_string) +} +inline void Summary_Image::set_encoded_image_string(const void* value, + size_t size) { + + encoded_image_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.Summary.Image.encoded_image_string) +} +inline std::string* Summary_Image::_internal_mutable_encoded_image_string() { + + return encoded_image_string_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Summary_Image::release_encoded_image_string() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Image.encoded_image_string) + return encoded_image_string_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Summary_Image::set_allocated_encoded_image_string(std::string* encoded_image_string) { + if (encoded_image_string != nullptr) { + + } else { + + } + encoded_image_string_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), encoded_image_string, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Image.encoded_image_string) +} + +// ------------------------------------------------------------------- + +// Summary_Audio + +// float sample_rate = 1; +inline void Summary_Audio::clear_sample_rate() { + sample_rate_ = 0; +} +inline float Summary_Audio::_internal_sample_rate() const { + return sample_rate_; +} +inline float Summary_Audio::sample_rate() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Audio.sample_rate) + return _internal_sample_rate(); +} +inline void Summary_Audio::_internal_set_sample_rate(float value) { + + sample_rate_ = value; +} +inline void Summary_Audio::set_sample_rate(float value) { + _internal_set_sample_rate(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Audio.sample_rate) +} + +// int64 num_channels = 2; +inline void Summary_Audio::clear_num_channels() { + num_channels_ = PROTOBUF_LONGLONG(0); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Audio::_internal_num_channels() const { + return num_channels_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Audio::num_channels() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Audio.num_channels) + return _internal_num_channels(); +} +inline void Summary_Audio::_internal_set_num_channels(::PROTOBUF_NAMESPACE_ID::int64 value) { + + num_channels_ = value; +} +inline void Summary_Audio::set_num_channels(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_num_channels(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Audio.num_channels) +} + +// int64 length_frames = 3; +inline void Summary_Audio::clear_length_frames() { + length_frames_ = PROTOBUF_LONGLONG(0); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Audio::_internal_length_frames() const { + return length_frames_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 Summary_Audio::length_frames() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Audio.length_frames) + return _internal_length_frames(); +} +inline void Summary_Audio::_internal_set_length_frames(::PROTOBUF_NAMESPACE_ID::int64 value) { + + length_frames_ = value; +} +inline void Summary_Audio::set_length_frames(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_length_frames(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Audio.length_frames) +} + +// bytes encoded_audio_string = 4; +inline void Summary_Audio::clear_encoded_audio_string() { + encoded_audio_string_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Summary_Audio::encoded_audio_string() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Audio.encoded_audio_string) + return _internal_encoded_audio_string(); +} +inline void Summary_Audio::set_encoded_audio_string(const std::string& value) { + _internal_set_encoded_audio_string(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Audio.encoded_audio_string) +} +inline std::string* Summary_Audio::mutable_encoded_audio_string() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Audio.encoded_audio_string) + return _internal_mutable_encoded_audio_string(); +} +inline const std::string& Summary_Audio::_internal_encoded_audio_string() const { + return encoded_audio_string_.Get(); +} +inline void Summary_Audio::_internal_set_encoded_audio_string(const std::string& value) { + + encoded_audio_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Summary_Audio::set_encoded_audio_string(std::string&& value) { + + encoded_audio_string_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.Summary.Audio.encoded_audio_string) +} +inline void Summary_Audio::set_encoded_audio_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + encoded_audio_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.Summary.Audio.encoded_audio_string) +} +inline void Summary_Audio::set_encoded_audio_string(const void* value, + size_t size) { + + encoded_audio_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.Summary.Audio.encoded_audio_string) +} +inline std::string* Summary_Audio::_internal_mutable_encoded_audio_string() { + + return encoded_audio_string_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Summary_Audio::release_encoded_audio_string() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Audio.encoded_audio_string) + return encoded_audio_string_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Summary_Audio::set_allocated_encoded_audio_string(std::string* encoded_audio_string) { + if (encoded_audio_string != nullptr) { + + } else { + + } + encoded_audio_string_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), encoded_audio_string, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Audio.encoded_audio_string) +} + +// string content_type = 5; +inline void Summary_Audio::clear_content_type() { + content_type_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Summary_Audio::content_type() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Audio.content_type) + return _internal_content_type(); +} +inline void Summary_Audio::set_content_type(const std::string& value) { + _internal_set_content_type(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Audio.content_type) +} +inline std::string* Summary_Audio::mutable_content_type() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Audio.content_type) + return _internal_mutable_content_type(); +} +inline const std::string& Summary_Audio::_internal_content_type() const { + return content_type_.Get(); +} +inline void Summary_Audio::_internal_set_content_type(const std::string& value) { + + content_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Summary_Audio::set_content_type(std::string&& value) { + + content_type_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.Summary.Audio.content_type) +} +inline void Summary_Audio::set_content_type(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + content_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.Summary.Audio.content_type) +} +inline void Summary_Audio::set_content_type(const char* value, + size_t size) { + + content_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.Summary.Audio.content_type) +} +inline std::string* Summary_Audio::_internal_mutable_content_type() { + + return content_type_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Summary_Audio::release_content_type() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Audio.content_type) + return content_type_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Summary_Audio::set_allocated_content_type(std::string* content_type) { + if (content_type != nullptr) { + + } else { + + } + content_type_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), content_type, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Audio.content_type) +} + +// ------------------------------------------------------------------- + +// Summary_Value + +// string node_name = 7; +inline void Summary_Value::clear_node_name() { + node_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Summary_Value::node_name() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Value.node_name) + return _internal_node_name(); +} +inline void Summary_Value::set_node_name(const std::string& value) { + _internal_set_node_name(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Value.node_name) +} +inline std::string* Summary_Value::mutable_node_name() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Value.node_name) + return _internal_mutable_node_name(); +} +inline const std::string& Summary_Value::_internal_node_name() const { + return node_name_.Get(); +} +inline void Summary_Value::_internal_set_node_name(const std::string& value) { + + node_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Summary_Value::set_node_name(std::string&& value) { + + node_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.Summary.Value.node_name) +} +inline void Summary_Value::set_node_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + node_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.Summary.Value.node_name) +} +inline void Summary_Value::set_node_name(const char* value, + size_t size) { + + node_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.Summary.Value.node_name) +} +inline std::string* Summary_Value::_internal_mutable_node_name() { + + return node_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Summary_Value::release_node_name() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Value.node_name) + return node_name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Summary_Value::set_allocated_node_name(std::string* node_name) { + if (node_name != nullptr) { + + } else { + + } + node_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), node_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Value.node_name) +} + +// string tag = 1; +inline void Summary_Value::clear_tag() { + tag_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& Summary_Value::tag() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Value.tag) + return _internal_tag(); +} +inline void Summary_Value::set_tag(const std::string& value) { + _internal_set_tag(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Value.tag) +} +inline std::string* Summary_Value::mutable_tag() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Value.tag) + return _internal_mutable_tag(); +} +inline const std::string& Summary_Value::_internal_tag() const { + return tag_.Get(); +} +inline void Summary_Value::_internal_set_tag(const std::string& value) { + + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Summary_Value::set_tag(std::string&& value) { + + tag_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.Summary.Value.tag) +} +inline void Summary_Value::set_tag(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.Summary.Value.tag) +} +inline void Summary_Value::set_tag(const char* value, + size_t size) { + + tag_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.Summary.Value.tag) +} +inline std::string* Summary_Value::_internal_mutable_tag() { + + return tag_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Summary_Value::release_tag() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Value.tag) + return tag_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void Summary_Value::set_allocated_tag(std::string* tag) { + if (tag != nullptr) { + + } else { + + } + tag_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), tag, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Value.tag) +} + +// .tensorboard.SummaryMetadata metadata = 9; +inline bool Summary_Value::_internal_has_metadata() const { + return this != internal_default_instance() && metadata_ != nullptr; +} +inline bool Summary_Value::has_metadata() const { + return _internal_has_metadata(); +} +inline void Summary_Value::clear_metadata() { + if (GetArena() == nullptr && metadata_ != nullptr) { + delete metadata_; + } + metadata_ = nullptr; +} +inline const ::tensorboard::SummaryMetadata& Summary_Value::_internal_metadata() const { + const ::tensorboard::SummaryMetadata* p = metadata_; + return p != nullptr ? *p : *reinterpret_cast( + &::tensorboard::_SummaryMetadata_default_instance_); +} +inline const ::tensorboard::SummaryMetadata& Summary_Value::metadata() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Value.metadata) + return _internal_metadata(); +} +inline void Summary_Value::unsafe_arena_set_allocated_metadata( + ::tensorboard::SummaryMetadata* metadata) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(metadata_); + } + metadata_ = metadata; + if (metadata) { + + } else { + + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.Summary.Value.metadata) +} +inline ::tensorboard::SummaryMetadata* Summary_Value::release_metadata() { + + ::tensorboard::SummaryMetadata* temp = metadata_; + metadata_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::tensorboard::SummaryMetadata* Summary_Value::unsafe_arena_release_metadata() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Value.metadata) + + ::tensorboard::SummaryMetadata* temp = metadata_; + metadata_ = nullptr; + return temp; +} +inline ::tensorboard::SummaryMetadata* Summary_Value::_internal_mutable_metadata() { + + if (metadata_ == nullptr) { + auto* p = CreateMaybeMessage<::tensorboard::SummaryMetadata>(GetArena()); + metadata_ = p; + } + return metadata_; +} +inline ::tensorboard::SummaryMetadata* Summary_Value::mutable_metadata() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Value.metadata) + return _internal_mutable_metadata(); +} +inline void Summary_Value::set_allocated_metadata(::tensorboard::SummaryMetadata* metadata) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete metadata_; + } + if (metadata) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(metadata); + if (message_arena != submessage_arena) { + metadata = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, metadata, submessage_arena); + } + + } else { + + } + metadata_ = metadata; + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Value.metadata) +} + +// float simple_value = 2; +inline bool Summary_Value::_internal_has_simple_value() const { + return value_case() == kSimpleValue; +} +inline void Summary_Value::set_has_simple_value() { + _oneof_case_[0] = kSimpleValue; +} +inline void Summary_Value::clear_simple_value() { + if (_internal_has_simple_value()) { + value_.simple_value_ = 0; + clear_has_value(); + } +} +inline float Summary_Value::_internal_simple_value() const { + if (_internal_has_simple_value()) { + return value_.simple_value_; + } + return 0; +} +inline void Summary_Value::_internal_set_simple_value(float value) { + if (!_internal_has_simple_value()) { + clear_value(); + set_has_simple_value(); + } + value_.simple_value_ = value; +} +inline float Summary_Value::simple_value() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Value.simple_value) + return _internal_simple_value(); +} +inline void Summary_Value::set_simple_value(float value) { + _internal_set_simple_value(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Value.simple_value) +} + +// bytes obsolete_old_style_histogram = 3; +inline bool Summary_Value::_internal_has_obsolete_old_style_histogram() const { + return value_case() == kObsoleteOldStyleHistogram; +} +inline void Summary_Value::set_has_obsolete_old_style_histogram() { + _oneof_case_[0] = kObsoleteOldStyleHistogram; +} +inline void Summary_Value::clear_obsolete_old_style_histogram() { + if (_internal_has_obsolete_old_style_histogram()) { + value_.obsolete_old_style_histogram_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + clear_has_value(); + } +} +inline const std::string& Summary_Value::obsolete_old_style_histogram() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Value.obsolete_old_style_histogram) + return _internal_obsolete_old_style_histogram(); +} +inline void Summary_Value::set_obsolete_old_style_histogram(const std::string& value) { + _internal_set_obsolete_old_style_histogram(value); + // @@protoc_insertion_point(field_set:tensorboard.Summary.Value.obsolete_old_style_histogram) +} +inline std::string* Summary_Value::mutable_obsolete_old_style_histogram() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Value.obsolete_old_style_histogram) + return _internal_mutable_obsolete_old_style_histogram(); +} +inline const std::string& Summary_Value::_internal_obsolete_old_style_histogram() const { + if (_internal_has_obsolete_old_style_histogram()) { + return value_.obsolete_old_style_histogram_.Get(); + } + return *&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(); +} +inline void Summary_Value::_internal_set_obsolete_old_style_histogram(const std::string& value) { + if (!_internal_has_obsolete_old_style_histogram()) { + clear_value(); + set_has_obsolete_old_style_histogram(); + value_.obsolete_old_style_histogram_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.obsolete_old_style_histogram_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void Summary_Value::set_obsolete_old_style_histogram(std::string&& value) { + // @@protoc_insertion_point(field_set:tensorboard.Summary.Value.obsolete_old_style_histogram) + if (!_internal_has_obsolete_old_style_histogram()) { + clear_value(); + set_has_obsolete_old_style_histogram(); + value_.obsolete_old_style_histogram_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.obsolete_old_style_histogram_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.Summary.Value.obsolete_old_style_histogram) +} +inline void Summary_Value::set_obsolete_old_style_histogram(const char* value) { + GOOGLE_DCHECK(value != nullptr); + if (!_internal_has_obsolete_old_style_histogram()) { + clear_value(); + set_has_obsolete_old_style_histogram(); + value_.obsolete_old_style_histogram_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.obsolete_old_style_histogram_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(value), GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.Summary.Value.obsolete_old_style_histogram) +} +inline void Summary_Value::set_obsolete_old_style_histogram(const void* value, + size_t size) { + if (!_internal_has_obsolete_old_style_histogram()) { + clear_value(); + set_has_obsolete_old_style_histogram(); + value_.obsolete_old_style_histogram_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.obsolete_old_style_histogram_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), + GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.Summary.Value.obsolete_old_style_histogram) +} +inline std::string* Summary_Value::_internal_mutable_obsolete_old_style_histogram() { + if (!_internal_has_obsolete_old_style_histogram()) { + clear_value(); + set_has_obsolete_old_style_histogram(); + value_.obsolete_old_style_histogram_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + return value_.obsolete_old_style_histogram_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* Summary_Value::release_obsolete_old_style_histogram() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Value.obsolete_old_style_histogram) + if (_internal_has_obsolete_old_style_histogram()) { + clear_has_value(); + return value_.obsolete_old_style_histogram_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + } else { + return nullptr; + } +} +inline void Summary_Value::set_allocated_obsolete_old_style_histogram(std::string* obsolete_old_style_histogram) { + if (has_value()) { + clear_value(); + } + if (obsolete_old_style_histogram != nullptr) { + set_has_obsolete_old_style_histogram(); + value_.obsolete_old_style_histogram_.UnsafeSetDefault(obsolete_old_style_histogram); + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); + if (arena != nullptr) { + arena->Own(obsolete_old_style_histogram); + } + } + // @@protoc_insertion_point(field_set_allocated:tensorboard.Summary.Value.obsolete_old_style_histogram) +} + +// .tensorboard.Summary.Image image = 4; +inline bool Summary_Value::_internal_has_image() const { + return value_case() == kImage; +} +inline bool Summary_Value::has_image() const { + return _internal_has_image(); +} +inline void Summary_Value::set_has_image() { + _oneof_case_[0] = kImage; +} +inline void Summary_Value::clear_image() { + if (_internal_has_image()) { + if (GetArena() == nullptr) { + delete value_.image_; + } + clear_has_value(); + } +} +inline ::tensorboard::Summary_Image* Summary_Value::release_image() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Value.image) + if (_internal_has_image()) { + clear_has_value(); + ::tensorboard::Summary_Image* temp = value_.image_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.image_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::tensorboard::Summary_Image& Summary_Value::_internal_image() const { + return _internal_has_image() + ? *value_.image_ + : *reinterpret_cast< ::tensorboard::Summary_Image*>(&::tensorboard::_Summary_Image_default_instance_); +} +inline const ::tensorboard::Summary_Image& Summary_Value::image() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Value.image) + return _internal_image(); +} +inline ::tensorboard::Summary_Image* Summary_Value::unsafe_arena_release_image() { + // @@protoc_insertion_point(field_unsafe_arena_release:tensorboard.Summary.Value.image) + if (_internal_has_image()) { + clear_has_value(); + ::tensorboard::Summary_Image* temp = value_.image_; + value_.image_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Summary_Value::unsafe_arena_set_allocated_image(::tensorboard::Summary_Image* image) { + clear_value(); + if (image) { + set_has_image(); + value_.image_ = image; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.Summary.Value.image) +} +inline ::tensorboard::Summary_Image* Summary_Value::_internal_mutable_image() { + if (!_internal_has_image()) { + clear_value(); + set_has_image(); + value_.image_ = CreateMaybeMessage< ::tensorboard::Summary_Image >(GetArena()); + } + return value_.image_; +} +inline ::tensorboard::Summary_Image* Summary_Value::mutable_image() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Value.image) + return _internal_mutable_image(); +} + +// .tensorboard.HistogramProto histo = 5; +inline bool Summary_Value::_internal_has_histo() const { + return value_case() == kHisto; +} +inline bool Summary_Value::has_histo() const { + return _internal_has_histo(); +} +inline void Summary_Value::set_has_histo() { + _oneof_case_[0] = kHisto; +} +inline ::tensorboard::HistogramProto* Summary_Value::release_histo() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Value.histo) + if (_internal_has_histo()) { + clear_has_value(); + ::tensorboard::HistogramProto* temp = value_.histo_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.histo_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::tensorboard::HistogramProto& Summary_Value::_internal_histo() const { + return _internal_has_histo() + ? *value_.histo_ + : *reinterpret_cast< ::tensorboard::HistogramProto*>(&::tensorboard::_HistogramProto_default_instance_); +} +inline const ::tensorboard::HistogramProto& Summary_Value::histo() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Value.histo) + return _internal_histo(); +} +inline ::tensorboard::HistogramProto* Summary_Value::unsafe_arena_release_histo() { + // @@protoc_insertion_point(field_unsafe_arena_release:tensorboard.Summary.Value.histo) + if (_internal_has_histo()) { + clear_has_value(); + ::tensorboard::HistogramProto* temp = value_.histo_; + value_.histo_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Summary_Value::unsafe_arena_set_allocated_histo(::tensorboard::HistogramProto* histo) { + clear_value(); + if (histo) { + set_has_histo(); + value_.histo_ = histo; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.Summary.Value.histo) +} +inline ::tensorboard::HistogramProto* Summary_Value::_internal_mutable_histo() { + if (!_internal_has_histo()) { + clear_value(); + set_has_histo(); + value_.histo_ = CreateMaybeMessage< ::tensorboard::HistogramProto >(GetArena()); + } + return value_.histo_; +} +inline ::tensorboard::HistogramProto* Summary_Value::mutable_histo() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Value.histo) + return _internal_mutable_histo(); +} + +// .tensorboard.Summary.Audio audio = 6; +inline bool Summary_Value::_internal_has_audio() const { + return value_case() == kAudio; +} +inline bool Summary_Value::has_audio() const { + return _internal_has_audio(); +} +inline void Summary_Value::set_has_audio() { + _oneof_case_[0] = kAudio; +} +inline void Summary_Value::clear_audio() { + if (_internal_has_audio()) { + if (GetArena() == nullptr) { + delete value_.audio_; + } + clear_has_value(); + } +} +inline ::tensorboard::Summary_Audio* Summary_Value::release_audio() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Value.audio) + if (_internal_has_audio()) { + clear_has_value(); + ::tensorboard::Summary_Audio* temp = value_.audio_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.audio_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::tensorboard::Summary_Audio& Summary_Value::_internal_audio() const { + return _internal_has_audio() + ? *value_.audio_ + : *reinterpret_cast< ::tensorboard::Summary_Audio*>(&::tensorboard::_Summary_Audio_default_instance_); +} +inline const ::tensorboard::Summary_Audio& Summary_Value::audio() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Value.audio) + return _internal_audio(); +} +inline ::tensorboard::Summary_Audio* Summary_Value::unsafe_arena_release_audio() { + // @@protoc_insertion_point(field_unsafe_arena_release:tensorboard.Summary.Value.audio) + if (_internal_has_audio()) { + clear_has_value(); + ::tensorboard::Summary_Audio* temp = value_.audio_; + value_.audio_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Summary_Value::unsafe_arena_set_allocated_audio(::tensorboard::Summary_Audio* audio) { + clear_value(); + if (audio) { + set_has_audio(); + value_.audio_ = audio; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.Summary.Value.audio) +} +inline ::tensorboard::Summary_Audio* Summary_Value::_internal_mutable_audio() { + if (!_internal_has_audio()) { + clear_value(); + set_has_audio(); + value_.audio_ = CreateMaybeMessage< ::tensorboard::Summary_Audio >(GetArena()); + } + return value_.audio_; +} +inline ::tensorboard::Summary_Audio* Summary_Value::mutable_audio() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Value.audio) + return _internal_mutable_audio(); +} + +// .tensorboard.TensorProto tensor = 8; +inline bool Summary_Value::_internal_has_tensor() const { + return value_case() == kTensor; +} +inline bool Summary_Value::has_tensor() const { + return _internal_has_tensor(); +} +inline void Summary_Value::set_has_tensor() { + _oneof_case_[0] = kTensor; +} +inline ::tensorboard::TensorProto* Summary_Value::release_tensor() { + // @@protoc_insertion_point(field_release:tensorboard.Summary.Value.tensor) + if (_internal_has_tensor()) { + clear_has_value(); + ::tensorboard::TensorProto* temp = value_.tensor_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.tensor_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::tensorboard::TensorProto& Summary_Value::_internal_tensor() const { + return _internal_has_tensor() + ? *value_.tensor_ + : *reinterpret_cast< ::tensorboard::TensorProto*>(&::tensorboard::_TensorProto_default_instance_); +} +inline const ::tensorboard::TensorProto& Summary_Value::tensor() const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.Value.tensor) + return _internal_tensor(); +} +inline ::tensorboard::TensorProto* Summary_Value::unsafe_arena_release_tensor() { + // @@protoc_insertion_point(field_unsafe_arena_release:tensorboard.Summary.Value.tensor) + if (_internal_has_tensor()) { + clear_has_value(); + ::tensorboard::TensorProto* temp = value_.tensor_; + value_.tensor_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void Summary_Value::unsafe_arena_set_allocated_tensor(::tensorboard::TensorProto* tensor) { + clear_value(); + if (tensor) { + set_has_tensor(); + value_.tensor_ = tensor; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.Summary.Value.tensor) +} +inline ::tensorboard::TensorProto* Summary_Value::_internal_mutable_tensor() { + if (!_internal_has_tensor()) { + clear_value(); + set_has_tensor(); + value_.tensor_ = CreateMaybeMessage< ::tensorboard::TensorProto >(GetArena()); + } + return value_.tensor_; +} +inline ::tensorboard::TensorProto* Summary_Value::mutable_tensor() { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.Value.tensor) + return _internal_mutable_tensor(); +} + +inline bool Summary_Value::has_value() const { + return value_case() != VALUE_NOT_SET; +} +inline void Summary_Value::clear_has_value() { + _oneof_case_[0] = VALUE_NOT_SET; +} +inline Summary_Value::ValueCase Summary_Value::value_case() const { + return Summary_Value::ValueCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + +// Summary + +// repeated .tensorboard.Summary.Value value = 1; +inline int Summary::_internal_value_size() const { + return value_.size(); +} +inline int Summary::value_size() const { + return _internal_value_size(); +} +inline void Summary::clear_value() { + value_.Clear(); +} +inline ::tensorboard::Summary_Value* Summary::mutable_value(int index) { + // @@protoc_insertion_point(field_mutable:tensorboard.Summary.value) + return value_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::Summary_Value >* +Summary::mutable_value() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.Summary.value) + return &value_; +} +inline const ::tensorboard::Summary_Value& Summary::_internal_value(int index) const { + return value_.Get(index); +} +inline const ::tensorboard::Summary_Value& Summary::value(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.Summary.value) + return _internal_value(index); +} +inline ::tensorboard::Summary_Value* Summary::_internal_add_value() { + return value_.Add(); +} +inline ::tensorboard::Summary_Value* Summary::add_value() { + // @@protoc_insertion_point(field_add:tensorboard.Summary.value) + return _internal_add_value(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::Summary_Value >& +Summary::value() const { + // @@protoc_insertion_point(field_list:tensorboard.Summary.value) + return value_; +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +} // namespace tensorboard + +PROTOBUF_NAMESPACE_OPEN + +template <> struct is_proto_enum< ::tensorboard::DataClass> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::tensorboard::DataClass>() { + return ::tensorboard::DataClass_descriptor(); +} + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_summary_2eproto diff --git a/plugins/mindstudio-insight-plugins/proto/summary.proto b/plugins/mindstudio-insight-plugins/proto/summary.proto new file mode 100644 index 0000000000000000000000000000000000000000..734d5e936384517a023d18bd3fa015e0bb5b67dd --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/summary.proto @@ -0,0 +1,132 @@ +syntax = "proto3"; + +package tensorboard; + +import public "histogram.proto"; +import "tensor.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "SummaryProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/summary_go_proto"; + +// Metadata associated with a series of Summary data +message SummaryDescription { + // Hint on how plugins should process the data in this series. + // Supported values include "scalar", "histogram", "image", "audio" + string type_hint = 1; +} + +// A SummaryMetadata encapsulates information on which plugins are able to make +// use of a certain summary value. +message SummaryMetadata { + message PluginData { + // The name of the plugin this data pertains to. + string plugin_name = 1; + + // The content to store for the plugin. The best practice is for this to be + // a binary serialized protocol buffer. + bytes content = 2; + } + + // Data that associates a summary with a certain plugin. + PluginData plugin_data = 1; + + // Display name for viewing in TensorBoard. + string display_name = 2; + + // Longform readable description of the summary sequence. Markdown supported. + string summary_description = 3; + + // Class of data stored in this time series. Required for compatibility with + // TensorBoard's generic data facilities (`DataProvider`, et al.). This value + // imposes constraints on the dtype and shape of the corresponding tensor + // values. See `DataClass` docs for details. + DataClass data_class = 4; +} + +enum DataClass { + // Unknown data class, used (implicitly) for legacy data. Will not be + // processed by data ingestion pipelines. + DATA_CLASS_UNKNOWN = 0; + // Scalar time series. Each `Value` for the corresponding tag must have + // `tensor` set to a rank-0 tensor of type `DT_FLOAT` (float32). + DATA_CLASS_SCALAR = 1; + // Tensor time series. Each `Value` for the corresponding tag must have + // `tensor` set. The tensor value is arbitrary, but should be small to + // accommodate direct storage in database backends: an upper bound of a few + // kilobytes is a reasonable rule of thumb. + DATA_CLASS_TENSOR = 2; + // Blob sequence time series. Each `Value` for the corresponding tag must + // have `tensor` set to a rank-1 tensor of bytestring dtype. + DATA_CLASS_BLOB_SEQUENCE = 3; +} + +// A Summary is a set of named values to be displayed by the +// visualizer. +// +// Summaries are produced regularly during training, as controlled by +// the "summary_interval_secs" attribute of the training operation. +// Summaries are also produced at the end of an evaluation. +message Summary { + message Image { + // Dimensions of the image. + int32 height = 1; + int32 width = 2; + // Valid colorspace values are + // 1 - grayscale + // 2 - grayscale + alpha + // 3 - RGB + // 4 - RGBA + // 5 - DIGITAL_YUV + // 6 - BGRA + int32 colorspace = 3; + // Image data in encoded format. All image formats supported by + // image_codec::CoderUtil can be stored here. + bytes encoded_image_string = 4; + } + + message Audio { + // Sample rate of the audio in Hz. + float sample_rate = 1; + // Number of channels of audio. + int64 num_channels = 2; + // Length of the audio in frames (samples per channel). + int64 length_frames = 3; + // Encoded audio data and its associated RFC 2045 content type (e.g. + // "audio/wav"). + bytes encoded_audio_string = 4; + string content_type = 5; + } + + message Value { + // This field is deprecated and will not be set. + string node_name = 7; + + // Tag name for the data. Used by TensorBoard plugins to organize data. Tags + // are often organized by scope (which contains slashes to convey + // hierarchy). For example: foo/bar/0 + string tag = 1; + + // Contains metadata on the summary value such as which plugins may use it. + // Take note that many summary values may lack a metadata field. This is + // because the FileWriter only keeps a metadata object on the first summary + // value with a certain tag for each tag. TensorBoard then remembers which + // tags are associated with which plugins. This saves space. + SummaryMetadata metadata = 9; + + // Value associated with the tag. + oneof value { + float simple_value = 2; + bytes obsolete_old_style_histogram = 3; + Image image = 4; + HistogramProto histo = 5; + Audio audio = 6; + TensorProto tensor = 8; + } + } + + // Set of values for the summary. + repeated Value value = 1; +} diff --git a/plugins/mindstudio-insight-plugins/proto/tensor.pb.cc b/plugins/mindstudio-insight-plugins/proto/tensor.pb.cc new file mode 100644 index 0000000000000000000000000000000000000000..3de94a97ee1ea0d3bdaa39ebe4bec76d6f89ce3b --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/tensor.pb.cc @@ -0,0 +1,1245 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensor.proto + +#include "tensor.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +extern PROTOBUF_INTERNAL_EXPORT_resource_5fhandle_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_ResourceHandleProto_resource_5fhandle_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_tensor_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<2> scc_info_TensorProto_tensor_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_tensor_5fshape_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TensorShapeProto_tensor_5fshape_2eproto; +namespace tensorboard { +class TensorProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TensorProto_default_instance_; +class VariantTensorDataProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _VariantTensorDataProto_default_instance_; +} // namespace tensorboard +static void InitDefaultsscc_info_TensorProto_tensor_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_TensorProto_default_instance_; + new (ptr) ::tensorboard::TensorProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::tensorboard::_VariantTensorDataProto_default_instance_; + new (ptr) ::tensorboard::VariantTensorDataProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::TensorProto::InitAsDefaultInstance(); + ::tensorboard::VariantTensorDataProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<2> scc_info_TensorProto_tensor_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 2, 0, InitDefaultsscc_info_TensorProto_tensor_2eproto}, { + &scc_info_TensorShapeProto_tensor_5fshape_2eproto.base, + &scc_info_ResourceHandleProto_resource_5fhandle_2eproto.base,}}; + +static ::PROTOBUF_NAMESPACE_ID::Metadata file_level_metadata_tensor_2eproto[2]; +static constexpr ::PROTOBUF_NAMESPACE_ID::EnumDescriptor const** file_level_enum_descriptors_tensor_2eproto = nullptr; +static constexpr ::PROTOBUF_NAMESPACE_ID::ServiceDescriptor const** file_level_service_descriptors_tensor_2eproto = nullptr; + +const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_tensor_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, dtype_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, tensor_shape_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, version_number_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, tensor_content_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, half_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, float_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, double_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, int_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, string_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, scomplex_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, int64_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, bool_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, dcomplex_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, resource_handle_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, variant_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, uint32_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, uint64_val_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorProto, float8_val_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::VariantTensorDataProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::VariantTensorDataProto, type_name_), + PROTOBUF_FIELD_OFFSET(::tensorboard::VariantTensorDataProto, metadata_), + PROTOBUF_FIELD_OFFSET(::tensorboard::VariantTensorDataProto, tensors_), +}; +static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, -1, sizeof(::tensorboard::TensorProto)}, + { 23, -1, sizeof(::tensorboard::VariantTensorDataProto)}, +}; + +static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = { + reinterpret_cast(&::tensorboard::_TensorProto_default_instance_), + reinterpret_cast(&::tensorboard::_VariantTensorDataProto_default_instance_), +}; + +const char descriptor_table_protodef_tensor_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = + "\n\014tensor.proto\022\013tensorboard\032\025resource_ha" + "ndle.proto\032\022tensor_shape.proto\032\013types.pr" + "oto\"\244\004\n\013TensorProto\022$\n\005dtype\030\001 \001(\0162\025.ten" + "sorboard.DataType\0223\n\014tensor_shape\030\002 \001(\0132" + "\035.tensorboard.TensorShapeProto\022\026\n\016versio" + "n_number\030\003 \001(\005\022\026\n\016tensor_content\030\004 \001(\014\022\024" + "\n\010half_val\030\r \003(\005B\002\020\001\022\025\n\tfloat_val\030\005 \003(\002B" + "\002\020\001\022\026\n\ndouble_val\030\006 \003(\001B\002\020\001\022\023\n\007int_val\030\007" + " \003(\005B\002\020\001\022\022\n\nstring_val\030\010 \003(\014\022\030\n\014scomplex" + "_val\030\t \003(\002B\002\020\001\022\025\n\tint64_val\030\n \003(\003B\002\020\001\022\024\n" + "\010bool_val\030\013 \003(\010B\002\020\001\022\030\n\014dcomplex_val\030\014 \003(" + "\001B\002\020\001\022=\n\023resource_handle_val\030\016 \003(\0132 .ten" + "sorboard.ResourceHandleProto\0228\n\013variant_" + "val\030\017 \003(\0132#.tensorboard.VariantTensorDat" + "aProto\022\026\n\nuint32_val\030\020 \003(\rB\002\020\001\022\026\n\nuint64" + "_val\030\021 \003(\004B\002\020\001\022\022\n\nfloat8_val\030\022 \001(\014\"h\n\026Va" + "riantTensorDataProto\022\021\n\ttype_name\030\001 \001(\t\022" + "\020\n\010metadata\030\002 \001(\014\022)\n\007tensors\030\003 \003(\0132\030.ten" + "sorboard.TensorProtoB|\n\030org.tensorflow.f" + "rameworkB\014TensorProtosP\001ZMgithub.com/ten" + "sorflow/tensorflow/tensorflow/go/core/fr" + "amework/tensor_go_proto\370\001\001b\006proto3" + ; +static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_tensor_2eproto_deps[3] = { + &::descriptor_table_resource_5fhandle_2eproto, + &::descriptor_table_tensor_5fshape_2eproto, + &::descriptor_table_types_2eproto, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_tensor_2eproto_sccs[1] = { + &scc_info_TensorProto_tensor_2eproto.base, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_tensor_2eproto_once; +const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_tensor_2eproto = { + false, false, descriptor_table_protodef_tensor_2eproto, "tensor.proto", 874, + &descriptor_table_tensor_2eproto_once, descriptor_table_tensor_2eproto_sccs, descriptor_table_tensor_2eproto_deps, 1, 3, + schemas, file_default_instances, TableStruct_tensor_2eproto::offsets, + file_level_metadata_tensor_2eproto, 2, file_level_enum_descriptors_tensor_2eproto, file_level_service_descriptors_tensor_2eproto, +}; + +// Force running AddDescriptors() at dynamic initialization time. +static bool dynamic_init_dummy_tensor_2eproto = (static_cast(::PROTOBUF_NAMESPACE_ID::internal::AddDescriptors(&descriptor_table_tensor_2eproto)), true); +namespace tensorboard { + +// =================================================================== + +void TensorProto::InitAsDefaultInstance() { + ::tensorboard::_TensorProto_default_instance_._instance.get_mutable()->tensor_shape_ = const_cast< ::tensorboard::TensorShapeProto*>( + ::tensorboard::TensorShapeProto::internal_default_instance()); +} +class TensorProto::_Internal { + public: + static const ::tensorboard::TensorShapeProto& tensor_shape(const TensorProto* msg); +}; + +const ::tensorboard::TensorShapeProto& +TensorProto::_Internal::tensor_shape(const TensorProto* msg) { + return *msg->tensor_shape_; +} +void TensorProto::clear_tensor_shape() { + if (GetArena() == nullptr && tensor_shape_ != nullptr) { + delete tensor_shape_; + } + tensor_shape_ = nullptr; +} +void TensorProto::clear_resource_handle_val() { + resource_handle_val_.Clear(); +} +TensorProto::TensorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + float_val_(arena), + double_val_(arena), + int_val_(arena), + string_val_(arena), + scomplex_val_(arena), + int64_val_(arena), + bool_val_(arena), + dcomplex_val_(arena), + half_val_(arena), + resource_handle_val_(arena), + variant_val_(arena), + uint32_val_(arena), + uint64_val_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.TensorProto) +} +TensorProto::TensorProto(const TensorProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + float_val_(from.float_val_), + double_val_(from.double_val_), + int_val_(from.int_val_), + string_val_(from.string_val_), + scomplex_val_(from.scomplex_val_), + int64_val_(from.int64_val_), + bool_val_(from.bool_val_), + dcomplex_val_(from.dcomplex_val_), + half_val_(from.half_val_), + resource_handle_val_(from.resource_handle_val_), + variant_val_(from.variant_val_), + uint32_val_(from.uint32_val_), + uint64_val_(from.uint64_val_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + tensor_content_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_tensor_content().empty()) { + tensor_content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_tensor_content(), + GetArena()); + } + float8_val_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_float8_val().empty()) { + float8_val_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_float8_val(), + GetArena()); + } + if (from._internal_has_tensor_shape()) { + tensor_shape_ = new ::tensorboard::TensorShapeProto(*from.tensor_shape_); + } else { + tensor_shape_ = nullptr; + } + ::memcpy(&dtype_, &from.dtype_, + static_cast(reinterpret_cast(&version_number_) - + reinterpret_cast(&dtype_)) + sizeof(version_number_)); + // @@protoc_insertion_point(copy_constructor:tensorboard.TensorProto) +} + +void TensorProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TensorProto_tensor_2eproto.base); + tensor_content_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + float8_val_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&tensor_shape_, 0, static_cast( + reinterpret_cast(&version_number_) - + reinterpret_cast(&tensor_shape_)) + sizeof(version_number_)); +} + +TensorProto::~TensorProto() { + // @@protoc_insertion_point(destructor:tensorboard.TensorProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TensorProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + tensor_content_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + float8_val_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete tensor_shape_; +} + +void TensorProto::ArenaDtor(void* object) { + TensorProto* _this = reinterpret_cast< TensorProto* >(object); + (void)_this; +} +void TensorProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TensorProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TensorProto& TensorProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorProto_tensor_2eproto.base); + return *internal_default_instance(); +} + + +void TensorProto::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.TensorProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + float_val_.Clear(); + double_val_.Clear(); + int_val_.Clear(); + string_val_.Clear(); + scomplex_val_.Clear(); + int64_val_.Clear(); + bool_val_.Clear(); + dcomplex_val_.Clear(); + half_val_.Clear(); + resource_handle_val_.Clear(); + variant_val_.Clear(); + uint32_val_.Clear(); + uint64_val_.Clear(); + tensor_content_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + float8_val_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + if (GetArena() == nullptr && tensor_shape_ != nullptr) { + delete tensor_shape_; + } + tensor_shape_ = nullptr; + ::memset(&dtype_, 0, static_cast( + reinterpret_cast(&version_number_) - + reinterpret_cast(&dtype_)) + sizeof(version_number_)); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TensorProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // .tensorboard.DataType dtype = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + _internal_set_dtype(static_cast<::tensorboard::DataType>(val)); + } else goto handle_unusual; + continue; + // .tensorboard.TensorShapeProto tensor_shape = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_tensor_shape(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // int32 version_number = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + version_number_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // bytes tensor_content = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + auto str = _internal_mutable_tensor_content(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float float_val = 5 [packed = true]; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_float_val(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 45) { + _internal_add_float_val(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // repeated double double_val = 6 [packed = true]; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedDoubleParser(_internal_mutable_double_val(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 49) { + _internal_add_double_val(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // repeated int32 int_val = 7 [packed = true]; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt32Parser(_internal_mutable_int_val(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 56) { + _internal_add_int_val(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated bytes string_val = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + ptr -= 1; + do { + ptr += 1; + auto str = _internal_add_string_val(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<66>(ptr)); + } else goto handle_unusual; + continue; + // repeated float scomplex_val = 9 [packed = true]; + case 9: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 74)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_scomplex_val(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 77) { + _internal_add_scomplex_val(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // repeated int64 int64_val = 10 [packed = true]; + case 10: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 82)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt64Parser(_internal_mutable_int64_val(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 80) { + _internal_add_int64_val(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated bool bool_val = 11 [packed = true]; + case 11: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 90)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedBoolParser(_internal_mutable_bool_val(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 88) { + _internal_add_bool_val(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated double dcomplex_val = 12 [packed = true]; + case 12: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 98)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedDoubleParser(_internal_mutable_dcomplex_val(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 97) { + _internal_add_dcomplex_val(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // repeated int32 half_val = 13 [packed = true]; + case 13: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 106)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt32Parser(_internal_mutable_half_val(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 104) { + _internal_add_half_val(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .tensorboard.ResourceHandleProto resource_handle_val = 14; + case 14: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 114)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_resource_handle_val(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<114>(ptr)); + } else goto handle_unusual; + continue; + // repeated .tensorboard.VariantTensorDataProto variant_val = 15; + case 15: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 122)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_variant_val(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<122>(ptr)); + } else goto handle_unusual; + continue; + // repeated uint32 uint32_val = 16 [packed = true]; + case 16: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 130)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedUInt32Parser(_internal_mutable_uint32_val(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 128) { + _internal_add_uint32_val(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated uint64 uint64_val = 17 [packed = true]; + case 17: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 138)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedUInt64Parser(_internal_mutable_uint64_val(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 136) { + _internal_add_uint64_val(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // bytes float8_val = 18; + case 18: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 146)) { + auto str = _internal_mutable_float8_val(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TensorProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.TensorProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // .tensorboard.DataType dtype = 1; + if (this->dtype() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_dtype(), target); + } + + // .tensorboard.TensorShapeProto tensor_shape = 2; + if (this->has_tensor_shape()) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::tensor_shape(this), target, stream); + } + + // int32 version_number = 3; + if (this->version_number() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt32ToArray(3, this->_internal_version_number(), target); + } + + // bytes tensor_content = 4; + if (this->tensor_content().size() > 0) { + target = stream->WriteBytesMaybeAliased( + 4, this->_internal_tensor_content(), target); + } + + // repeated float float_val = 5 [packed = true]; + if (this->_internal_float_val_size() > 0) { + target = stream->WriteFixedPacked(5, _internal_float_val(), target); + } + + // repeated double double_val = 6 [packed = true]; + if (this->_internal_double_val_size() > 0) { + target = stream->WriteFixedPacked(6, _internal_double_val(), target); + } + + // repeated int32 int_val = 7 [packed = true]; + { + int byte_size = _int_val_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteInt32Packed( + 7, _internal_int_val(), byte_size, target); + } + } + + // repeated bytes string_val = 8; + for (int i = 0, n = this->_internal_string_val_size(); i < n; i++) { + const auto& s = this->_internal_string_val(i); + target = stream->WriteBytes(8, s, target); + } + + // repeated float scomplex_val = 9 [packed = true]; + if (this->_internal_scomplex_val_size() > 0) { + target = stream->WriteFixedPacked(9, _internal_scomplex_val(), target); + } + + // repeated int64 int64_val = 10 [packed = true]; + { + int byte_size = _int64_val_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteInt64Packed( + 10, _internal_int64_val(), byte_size, target); + } + } + + // repeated bool bool_val = 11 [packed = true]; + if (this->_internal_bool_val_size() > 0) { + target = stream->WriteFixedPacked(11, _internal_bool_val(), target); + } + + // repeated double dcomplex_val = 12 [packed = true]; + if (this->_internal_dcomplex_val_size() > 0) { + target = stream->WriteFixedPacked(12, _internal_dcomplex_val(), target); + } + + // repeated int32 half_val = 13 [packed = true]; + { + int byte_size = _half_val_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteInt32Packed( + 13, _internal_half_val(), byte_size, target); + } + } + + // repeated .tensorboard.ResourceHandleProto resource_handle_val = 14; + for (unsigned int i = 0, + n = static_cast(this->_internal_resource_handle_val_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(14, this->_internal_resource_handle_val(i), target, stream); + } + + // repeated .tensorboard.VariantTensorDataProto variant_val = 15; + for (unsigned int i = 0, + n = static_cast(this->_internal_variant_val_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(15, this->_internal_variant_val(i), target, stream); + } + + // repeated uint32 uint32_val = 16 [packed = true]; + { + int byte_size = _uint32_val_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteUInt32Packed( + 16, _internal_uint32_val(), byte_size, target); + } + } + + // repeated uint64 uint64_val = 17 [packed = true]; + { + int byte_size = _uint64_val_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteUInt64Packed( + 17, _internal_uint64_val(), byte_size, target); + } + } + + // bytes float8_val = 18; + if (this->float8_val().size() > 0) { + target = stream->WriteBytesMaybeAliased( + 18, this->_internal_float8_val(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.TensorProto) + return target; +} + +size_t TensorProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.TensorProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated float float_val = 5 [packed = true]; + { + unsigned int count = static_cast(this->_internal_float_val_size()); + size_t data_size = 4UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _float_val_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated double double_val = 6 [packed = true]; + { + unsigned int count = static_cast(this->_internal_double_val_size()); + size_t data_size = 8UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _double_val_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated int32 int_val = 7 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int32Size(this->int_val_); + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _int_val_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated bytes string_val = 8; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(string_val_.size()); + for (int i = 0, n = string_val_.size(); i < n; i++) { + total_size += ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + string_val_.Get(i)); + } + + // repeated float scomplex_val = 9 [packed = true]; + { + unsigned int count = static_cast(this->_internal_scomplex_val_size()); + size_t data_size = 4UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _scomplex_val_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated int64 int64_val = 10 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int64Size(this->int64_val_); + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _int64_val_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated bool bool_val = 11 [packed = true]; + { + unsigned int count = static_cast(this->_internal_bool_val_size()); + size_t data_size = 1UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _bool_val_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated double dcomplex_val = 12 [packed = true]; + { + unsigned int count = static_cast(this->_internal_dcomplex_val_size()); + size_t data_size = 8UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _dcomplex_val_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated int32 half_val = 13 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int32Size(this->half_val_); + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _half_val_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated .tensorboard.ResourceHandleProto resource_handle_val = 14; + total_size += 1UL * this->_internal_resource_handle_val_size(); + for (const auto& msg : this->resource_handle_val_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .tensorboard.VariantTensorDataProto variant_val = 15; + total_size += 1UL * this->_internal_variant_val_size(); + for (const auto& msg : this->variant_val_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated uint32 uint32_val = 16 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + UInt32Size(this->uint32_val_); + if (data_size > 0) { + total_size += 2 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _uint32_val_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated uint64 uint64_val = 17 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + UInt64Size(this->uint64_val_); + if (data_size > 0) { + total_size += 2 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _uint64_val_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // bytes tensor_content = 4; + if (this->tensor_content().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_tensor_content()); + } + + // bytes float8_val = 18; + if (this->float8_val().size() > 0) { + total_size += 2 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_float8_val()); + } + + // .tensorboard.TensorShapeProto tensor_shape = 2; + if (this->has_tensor_shape()) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *tensor_shape_); + } + + // .tensorboard.DataType dtype = 1; + if (this->dtype() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_dtype()); + } + + // int32 version_number = 3; + if (this->version_number() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + this->_internal_version_number()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TensorProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.TensorProto) + GOOGLE_DCHECK_NE(&from, this); + const TensorProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.TensorProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.TensorProto) + MergeFrom(*source); + } +} + +void TensorProto::MergeFrom(const TensorProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.TensorProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + float_val_.MergeFrom(from.float_val_); + double_val_.MergeFrom(from.double_val_); + int_val_.MergeFrom(from.int_val_); + string_val_.MergeFrom(from.string_val_); + scomplex_val_.MergeFrom(from.scomplex_val_); + int64_val_.MergeFrom(from.int64_val_); + bool_val_.MergeFrom(from.bool_val_); + dcomplex_val_.MergeFrom(from.dcomplex_val_); + half_val_.MergeFrom(from.half_val_); + resource_handle_val_.MergeFrom(from.resource_handle_val_); + variant_val_.MergeFrom(from.variant_val_); + uint32_val_.MergeFrom(from.uint32_val_); + uint64_val_.MergeFrom(from.uint64_val_); + if (from.tensor_content().size() > 0) { + _internal_set_tensor_content(from._internal_tensor_content()); + } + if (from.float8_val().size() > 0) { + _internal_set_float8_val(from._internal_float8_val()); + } + if (from.has_tensor_shape()) { + _internal_mutable_tensor_shape()->::tensorboard::TensorShapeProto::MergeFrom(from._internal_tensor_shape()); + } + if (from.dtype() != 0) { + _internal_set_dtype(from._internal_dtype()); + } + if (from.version_number() != 0) { + _internal_set_version_number(from._internal_version_number()); + } +} + +void TensorProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.TensorProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TensorProto::CopyFrom(const TensorProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.TensorProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TensorProto::IsInitialized() const { + return true; +} + +void TensorProto::InternalSwap(TensorProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + float_val_.InternalSwap(&other->float_val_); + double_val_.InternalSwap(&other->double_val_); + int_val_.InternalSwap(&other->int_val_); + string_val_.InternalSwap(&other->string_val_); + scomplex_val_.InternalSwap(&other->scomplex_val_); + int64_val_.InternalSwap(&other->int64_val_); + bool_val_.InternalSwap(&other->bool_val_); + dcomplex_val_.InternalSwap(&other->dcomplex_val_); + half_val_.InternalSwap(&other->half_val_); + resource_handle_val_.InternalSwap(&other->resource_handle_val_); + variant_val_.InternalSwap(&other->variant_val_); + uint32_val_.InternalSwap(&other->uint32_val_); + uint64_val_.InternalSwap(&other->uint64_val_); + tensor_content_.Swap(&other->tensor_content_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + float8_val_.Swap(&other->float8_val_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(TensorProto, version_number_) + + sizeof(TensorProto::version_number_) + - PROTOBUF_FIELD_OFFSET(TensorProto, tensor_shape_)>( + reinterpret_cast(&tensor_shape_), + reinterpret_cast(&other->tensor_shape_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TensorProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void VariantTensorDataProto::InitAsDefaultInstance() { +} +class VariantTensorDataProto::_Internal { + public: +}; + +VariantTensorDataProto::VariantTensorDataProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + tensors_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.VariantTensorDataProto) +} +VariantTensorDataProto::VariantTensorDataProto(const VariantTensorDataProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + tensors_(from.tensors_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + type_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_type_name().empty()) { + type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_type_name(), + GetArena()); + } + metadata_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_metadata().empty()) { + metadata_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_metadata(), + GetArena()); + } + // @@protoc_insertion_point(copy_constructor:tensorboard.VariantTensorDataProto) +} + +void VariantTensorDataProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TensorProto_tensor_2eproto.base); + type_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + metadata_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +VariantTensorDataProto::~VariantTensorDataProto() { + // @@protoc_insertion_point(destructor:tensorboard.VariantTensorDataProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void VariantTensorDataProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + type_name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + metadata_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void VariantTensorDataProto::ArenaDtor(void* object) { + VariantTensorDataProto* _this = reinterpret_cast< VariantTensorDataProto* >(object); + (void)_this; +} +void VariantTensorDataProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void VariantTensorDataProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const VariantTensorDataProto& VariantTensorDataProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorProto_tensor_2eproto.base); + return *internal_default_instance(); +} + + +void VariantTensorDataProto::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.VariantTensorDataProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + tensors_.Clear(); + type_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + metadata_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* VariantTensorDataProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // string type_name = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_type_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.VariantTensorDataProto.type_name")); + CHK_(ptr); + } else goto handle_unusual; + continue; + // bytes metadata = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_metadata(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .tensorboard.TensorProto tensors = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_tensors(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<26>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* VariantTensorDataProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.VariantTensorDataProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // string type_name = 1; + if (this->type_name().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_type_name().data(), static_cast(this->_internal_type_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.VariantTensorDataProto.type_name"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_type_name(), target); + } + + // bytes metadata = 2; + if (this->metadata().size() > 0) { + target = stream->WriteBytesMaybeAliased( + 2, this->_internal_metadata(), target); + } + + // repeated .tensorboard.TensorProto tensors = 3; + for (unsigned int i = 0, + n = static_cast(this->_internal_tensors_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(3, this->_internal_tensors(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.VariantTensorDataProto) + return target; +} + +size_t VariantTensorDataProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.VariantTensorDataProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .tensorboard.TensorProto tensors = 3; + total_size += 1UL * this->_internal_tensors_size(); + for (const auto& msg : this->tensors_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // string type_name = 1; + if (this->type_name().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_type_name()); + } + + // bytes metadata = 2; + if (this->metadata().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_metadata()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void VariantTensorDataProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.VariantTensorDataProto) + GOOGLE_DCHECK_NE(&from, this); + const VariantTensorDataProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.VariantTensorDataProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.VariantTensorDataProto) + MergeFrom(*source); + } +} + +void VariantTensorDataProto::MergeFrom(const VariantTensorDataProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.VariantTensorDataProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + tensors_.MergeFrom(from.tensors_); + if (from.type_name().size() > 0) { + _internal_set_type_name(from._internal_type_name()); + } + if (from.metadata().size() > 0) { + _internal_set_metadata(from._internal_metadata()); + } +} + +void VariantTensorDataProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.VariantTensorDataProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void VariantTensorDataProto::CopyFrom(const VariantTensorDataProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.VariantTensorDataProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool VariantTensorDataProto::IsInitialized() const { + return true; +} + +void VariantTensorDataProto::InternalSwap(VariantTensorDataProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + tensors_.InternalSwap(&other->tensors_); + type_name_.Swap(&other->type_name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + metadata_.Swap(&other->metadata_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata VariantTensorDataProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_NOINLINE ::tensorboard::TensorProto* Arena::CreateMaybeMessage< ::tensorboard::TensorProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::TensorProto >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::VariantTensorDataProto* Arena::CreateMaybeMessage< ::tensorboard::VariantTensorDataProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::VariantTensorDataProto >(arena); +} +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) +#include diff --git a/plugins/mindstudio-insight-plugins/proto/tensor.pb.h b/plugins/mindstudio-insight-plugins/proto/tensor.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..40c723941fc3b9840552ea36b3b773f11d05ca6e --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/tensor.pb.h @@ -0,0 +1,1827 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensor.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_tensor_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_tensor_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include "resource_handle.pb.h" +#include "tensor_shape.pb.h" +#include "types.pb.h" +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_tensor_2eproto +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct TableStruct_tensor_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[2] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_tensor_2eproto; +namespace tensorboard { +class TensorProto; +class TensorProtoDefaultTypeInternal; +extern TensorProtoDefaultTypeInternal _TensorProto_default_instance_; +class VariantTensorDataProto; +class VariantTensorDataProtoDefaultTypeInternal; +extern VariantTensorDataProtoDefaultTypeInternal _VariantTensorDataProto_default_instance_; +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> ::tensorboard::TensorProto* Arena::CreateMaybeMessage<::tensorboard::TensorProto>(Arena*); +template<> ::tensorboard::VariantTensorDataProto* Arena::CreateMaybeMessage<::tensorboard::VariantTensorDataProto>(Arena*); +PROTOBUF_NAMESPACE_CLOSE +namespace tensorboard { + +// =================================================================== + +class TensorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.TensorProto) */ { + public: + inline TensorProto() : TensorProto(nullptr) {} + virtual ~TensorProto(); + + TensorProto(const TensorProto& from); + TensorProto(TensorProto&& from) noexcept + : TensorProto() { + *this = ::std::move(from); + } + + inline TensorProto& operator=(const TensorProto& from) { + CopyFrom(from); + return *this; + } + inline TensorProto& operator=(TensorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorProto* internal_default_instance() { + return reinterpret_cast( + &_TensorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(TensorProto& a, TensorProto& b) { + a.Swap(&b); + } + inline void Swap(TensorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TensorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorProto& from); + void MergeFrom(const TensorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.TensorProto"; + } + protected: + explicit TensorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_tensor_2eproto); + return ::descriptor_table_tensor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kFloatValFieldNumber = 5, + kDoubleValFieldNumber = 6, + kIntValFieldNumber = 7, + kStringValFieldNumber = 8, + kScomplexValFieldNumber = 9, + kInt64ValFieldNumber = 10, + kBoolValFieldNumber = 11, + kDcomplexValFieldNumber = 12, + kHalfValFieldNumber = 13, + kResourceHandleValFieldNumber = 14, + kVariantValFieldNumber = 15, + kUint32ValFieldNumber = 16, + kUint64ValFieldNumber = 17, + kTensorContentFieldNumber = 4, + kFloat8ValFieldNumber = 18, + kTensorShapeFieldNumber = 2, + kDtypeFieldNumber = 1, + kVersionNumberFieldNumber = 3, + }; + // repeated float float_val = 5 [packed = true]; + int float_val_size() const; + private: + int _internal_float_val_size() const; + public: + void clear_float_val(); + private: + float _internal_float_val(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_float_val() const; + void _internal_add_float_val(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_float_val(); + public: + float float_val(int index) const; + void set_float_val(int index, float value); + void add_float_val(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + float_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_float_val(); + + // repeated double double_val = 6 [packed = true]; + int double_val_size() const; + private: + int _internal_double_val_size() const; + public: + void clear_double_val(); + private: + double _internal_double_val(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + _internal_double_val() const; + void _internal_add_double_val(double value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + _internal_mutable_double_val(); + public: + double double_val(int index) const; + void set_double_val(int index, double value); + void add_double_val(double value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + double_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + mutable_double_val(); + + // repeated int32 int_val = 7 [packed = true]; + int int_val_size() const; + private: + int _internal_int_val_size() const; + public: + void clear_int_val(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_int_val(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_int_val() const; + void _internal_add_int_val(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_int_val(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 int_val(int index) const; + void set_int_val(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_int_val(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + int_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_int_val(); + + // repeated bytes string_val = 8; + int string_val_size() const; + private: + int _internal_string_val_size() const; + public: + void clear_string_val(); + const std::string& string_val(int index) const; + std::string* mutable_string_val(int index); + void set_string_val(int index, const std::string& value); + void set_string_val(int index, std::string&& value); + void set_string_val(int index, const char* value); + void set_string_val(int index, const void* value, size_t size); + std::string* add_string_val(); + void add_string_val(const std::string& value); + void add_string_val(std::string&& value); + void add_string_val(const char* value); + void add_string_val(const void* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& string_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_string_val(); + private: + const std::string& _internal_string_val(int index) const; + std::string* _internal_add_string_val(); + public: + + // repeated float scomplex_val = 9 [packed = true]; + int scomplex_val_size() const; + private: + int _internal_scomplex_val_size() const; + public: + void clear_scomplex_val(); + private: + float _internal_scomplex_val(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_scomplex_val() const; + void _internal_add_scomplex_val(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_scomplex_val(); + public: + float scomplex_val(int index) const; + void set_scomplex_val(int index, float value); + void add_scomplex_val(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + scomplex_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_scomplex_val(); + + // repeated int64 int64_val = 10 [packed = true]; + int int64_val_size() const; + private: + int _internal_int64_val_size() const; + public: + void clear_int64_val(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_int64_val(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_int64_val() const; + void _internal_add_int64_val(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_int64_val(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 int64_val(int index) const; + void set_int64_val(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_int64_val(::PROTOBUF_NAMESPACE_ID::int64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + int64_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_int64_val(); + + // repeated bool bool_val = 11 [packed = true]; + int bool_val_size() const; + private: + int _internal_bool_val_size() const; + public: + void clear_bool_val(); + private: + bool _internal_bool_val(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >& + _internal_bool_val() const; + void _internal_add_bool_val(bool value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >* + _internal_mutable_bool_val(); + public: + bool bool_val(int index) const; + void set_bool_val(int index, bool value); + void add_bool_val(bool value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >& + bool_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >* + mutable_bool_val(); + + // repeated double dcomplex_val = 12 [packed = true]; + int dcomplex_val_size() const; + private: + int _internal_dcomplex_val_size() const; + public: + void clear_dcomplex_val(); + private: + double _internal_dcomplex_val(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + _internal_dcomplex_val() const; + void _internal_add_dcomplex_val(double value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + _internal_mutable_dcomplex_val(); + public: + double dcomplex_val(int index) const; + void set_dcomplex_val(int index, double value); + void add_dcomplex_val(double value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + dcomplex_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + mutable_dcomplex_val(); + + // repeated int32 half_val = 13 [packed = true]; + int half_val_size() const; + private: + int _internal_half_val_size() const; + public: + void clear_half_val(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_half_val(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_half_val() const; + void _internal_add_half_val(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_half_val(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 half_val(int index) const; + void set_half_val(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_half_val(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + half_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_half_val(); + + // repeated .tensorboard.ResourceHandleProto resource_handle_val = 14; + int resource_handle_val_size() const; + private: + int _internal_resource_handle_val_size() const; + public: + void clear_resource_handle_val(); + ::tensorboard::ResourceHandleProto* mutable_resource_handle_val(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::ResourceHandleProto >* + mutable_resource_handle_val(); + private: + const ::tensorboard::ResourceHandleProto& _internal_resource_handle_val(int index) const; + ::tensorboard::ResourceHandleProto* _internal_add_resource_handle_val(); + public: + const ::tensorboard::ResourceHandleProto& resource_handle_val(int index) const; + ::tensorboard::ResourceHandleProto* add_resource_handle_val(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::ResourceHandleProto >& + resource_handle_val() const; + + // repeated .tensorboard.VariantTensorDataProto variant_val = 15; + int variant_val_size() const; + private: + int _internal_variant_val_size() const; + public: + void clear_variant_val(); + ::tensorboard::VariantTensorDataProto* mutable_variant_val(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::VariantTensorDataProto >* + mutable_variant_val(); + private: + const ::tensorboard::VariantTensorDataProto& _internal_variant_val(int index) const; + ::tensorboard::VariantTensorDataProto* _internal_add_variant_val(); + public: + const ::tensorboard::VariantTensorDataProto& variant_val(int index) const; + ::tensorboard::VariantTensorDataProto* add_variant_val(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::VariantTensorDataProto >& + variant_val() const; + + // repeated uint32 uint32_val = 16 [packed = true]; + int uint32_val_size() const; + private: + int _internal_uint32_val_size() const; + public: + void clear_uint32_val(); + private: + ::PROTOBUF_NAMESPACE_ID::uint32 _internal_uint32_val(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >& + _internal_uint32_val() const; + void _internal_add_uint32_val(::PROTOBUF_NAMESPACE_ID::uint32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >* + _internal_mutable_uint32_val(); + public: + ::PROTOBUF_NAMESPACE_ID::uint32 uint32_val(int index) const; + void set_uint32_val(int index, ::PROTOBUF_NAMESPACE_ID::uint32 value); + void add_uint32_val(::PROTOBUF_NAMESPACE_ID::uint32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >& + uint32_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >* + mutable_uint32_val(); + + // repeated uint64 uint64_val = 17 [packed = true]; + int uint64_val_size() const; + private: + int _internal_uint64_val_size() const; + public: + void clear_uint64_val(); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_uint64_val(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& + _internal_uint64_val() const; + void _internal_add_uint64_val(::PROTOBUF_NAMESPACE_ID::uint64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* + _internal_mutable_uint64_val(); + public: + ::PROTOBUF_NAMESPACE_ID::uint64 uint64_val(int index) const; + void set_uint64_val(int index, ::PROTOBUF_NAMESPACE_ID::uint64 value); + void add_uint64_val(::PROTOBUF_NAMESPACE_ID::uint64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& + uint64_val() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* + mutable_uint64_val(); + + // bytes tensor_content = 4; + void clear_tensor_content(); + const std::string& tensor_content() const; + void set_tensor_content(const std::string& value); + void set_tensor_content(std::string&& value); + void set_tensor_content(const char* value); + void set_tensor_content(const void* value, size_t size); + std::string* mutable_tensor_content(); + std::string* release_tensor_content(); + void set_allocated_tensor_content(std::string* tensor_content); + private: + const std::string& _internal_tensor_content() const; + void _internal_set_tensor_content(const std::string& value); + std::string* _internal_mutable_tensor_content(); + public: + + // bytes float8_val = 18; + void clear_float8_val(); + const std::string& float8_val() const; + void set_float8_val(const std::string& value); + void set_float8_val(std::string&& value); + void set_float8_val(const char* value); + void set_float8_val(const void* value, size_t size); + std::string* mutable_float8_val(); + std::string* release_float8_val(); + void set_allocated_float8_val(std::string* float8_val); + private: + const std::string& _internal_float8_val() const; + void _internal_set_float8_val(const std::string& value); + std::string* _internal_mutable_float8_val(); + public: + + // .tensorboard.TensorShapeProto tensor_shape = 2; + bool has_tensor_shape() const; + private: + bool _internal_has_tensor_shape() const; + public: + void clear_tensor_shape(); + const ::tensorboard::TensorShapeProto& tensor_shape() const; + ::tensorboard::TensorShapeProto* release_tensor_shape(); + ::tensorboard::TensorShapeProto* mutable_tensor_shape(); + void set_allocated_tensor_shape(::tensorboard::TensorShapeProto* tensor_shape); + private: + const ::tensorboard::TensorShapeProto& _internal_tensor_shape() const; + ::tensorboard::TensorShapeProto* _internal_mutable_tensor_shape(); + public: + void unsafe_arena_set_allocated_tensor_shape( + ::tensorboard::TensorShapeProto* tensor_shape); + ::tensorboard::TensorShapeProto* unsafe_arena_release_tensor_shape(); + + // .tensorboard.DataType dtype = 1; + void clear_dtype(); + ::tensorboard::DataType dtype() const; + void set_dtype(::tensorboard::DataType value); + private: + ::tensorboard::DataType _internal_dtype() const; + void _internal_set_dtype(::tensorboard::DataType value); + public: + + // int32 version_number = 3; + void clear_version_number(); + ::PROTOBUF_NAMESPACE_ID::int32 version_number() const; + void set_version_number(::PROTOBUF_NAMESPACE_ID::int32 value); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_version_number() const; + void _internal_set_version_number(::PROTOBUF_NAMESPACE_ID::int32 value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.TensorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > float_val_; + mutable std::atomic _float_val_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double > double_val_; + mutable std::atomic _double_val_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > int_val_; + mutable std::atomic _int_val_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField string_val_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > scomplex_val_; + mutable std::atomic _scomplex_val_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > int64_val_; + mutable std::atomic _int64_val_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool > bool_val_; + mutable std::atomic _bool_val_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double > dcomplex_val_; + mutable std::atomic _dcomplex_val_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > half_val_; + mutable std::atomic _half_val_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::ResourceHandleProto > resource_handle_val_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::VariantTensorDataProto > variant_val_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 > uint32_val_; + mutable std::atomic _uint32_val_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 > uint64_val_; + mutable std::atomic _uint64_val_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr tensor_content_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr float8_val_; + ::tensorboard::TensorShapeProto* tensor_shape_; + int dtype_; + ::PROTOBUF_NAMESPACE_ID::int32 version_number_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_tensor_2eproto; +}; +// ------------------------------------------------------------------- + +class VariantTensorDataProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.VariantTensorDataProto) */ { + public: + inline VariantTensorDataProto() : VariantTensorDataProto(nullptr) {} + virtual ~VariantTensorDataProto(); + + VariantTensorDataProto(const VariantTensorDataProto& from); + VariantTensorDataProto(VariantTensorDataProto&& from) noexcept + : VariantTensorDataProto() { + *this = ::std::move(from); + } + + inline VariantTensorDataProto& operator=(const VariantTensorDataProto& from) { + CopyFrom(from); + return *this; + } + inline VariantTensorDataProto& operator=(VariantTensorDataProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const VariantTensorDataProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const VariantTensorDataProto* internal_default_instance() { + return reinterpret_cast( + &_VariantTensorDataProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(VariantTensorDataProto& a, VariantTensorDataProto& b) { + a.Swap(&b); + } + inline void Swap(VariantTensorDataProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(VariantTensorDataProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline VariantTensorDataProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + VariantTensorDataProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const VariantTensorDataProto& from); + void MergeFrom(const VariantTensorDataProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(VariantTensorDataProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.VariantTensorDataProto"; + } + protected: + explicit VariantTensorDataProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_tensor_2eproto); + return ::descriptor_table_tensor_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kTensorsFieldNumber = 3, + kTypeNameFieldNumber = 1, + kMetadataFieldNumber = 2, + }; + // repeated .tensorboard.TensorProto tensors = 3; + int tensors_size() const; + private: + int _internal_tensors_size() const; + public: + void clear_tensors(); + ::tensorboard::TensorProto* mutable_tensors(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::TensorProto >* + mutable_tensors(); + private: + const ::tensorboard::TensorProto& _internal_tensors(int index) const; + ::tensorboard::TensorProto* _internal_add_tensors(); + public: + const ::tensorboard::TensorProto& tensors(int index) const; + ::tensorboard::TensorProto* add_tensors(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::TensorProto >& + tensors() const; + + // string type_name = 1; + void clear_type_name(); + const std::string& type_name() const; + void set_type_name(const std::string& value); + void set_type_name(std::string&& value); + void set_type_name(const char* value); + void set_type_name(const char* value, size_t size); + std::string* mutable_type_name(); + std::string* release_type_name(); + void set_allocated_type_name(std::string* type_name); + private: + const std::string& _internal_type_name() const; + void _internal_set_type_name(const std::string& value); + std::string* _internal_mutable_type_name(); + public: + + // bytes metadata = 2; + void clear_metadata(); + const std::string& metadata() const; + void set_metadata(const std::string& value); + void set_metadata(std::string&& value); + void set_metadata(const char* value); + void set_metadata(const void* value, size_t size); + std::string* mutable_metadata(); + std::string* release_metadata(); + void set_allocated_metadata(std::string* metadata); + private: + const std::string& _internal_metadata() const; + void _internal_set_metadata(const std::string& value); + std::string* _internal_mutable_metadata(); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.VariantTensorDataProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::TensorProto > tensors_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr type_name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr metadata_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_tensor_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// TensorProto + +// .tensorboard.DataType dtype = 1; +inline void TensorProto::clear_dtype() { + dtype_ = 0; +} +inline ::tensorboard::DataType TensorProto::_internal_dtype() const { + return static_cast< ::tensorboard::DataType >(dtype_); +} +inline ::tensorboard::DataType TensorProto::dtype() const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.dtype) + return _internal_dtype(); +} +inline void TensorProto::_internal_set_dtype(::tensorboard::DataType value) { + + dtype_ = value; +} +inline void TensorProto::set_dtype(::tensorboard::DataType value) { + _internal_set_dtype(value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.dtype) +} + +// .tensorboard.TensorShapeProto tensor_shape = 2; +inline bool TensorProto::_internal_has_tensor_shape() const { + return this != internal_default_instance() && tensor_shape_ != nullptr; +} +inline bool TensorProto::has_tensor_shape() const { + return _internal_has_tensor_shape(); +} +inline const ::tensorboard::TensorShapeProto& TensorProto::_internal_tensor_shape() const { + const ::tensorboard::TensorShapeProto* p = tensor_shape_; + return p != nullptr ? *p : *reinterpret_cast( + &::tensorboard::_TensorShapeProto_default_instance_); +} +inline const ::tensorboard::TensorShapeProto& TensorProto::tensor_shape() const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.tensor_shape) + return _internal_tensor_shape(); +} +inline void TensorProto::unsafe_arena_set_allocated_tensor_shape( + ::tensorboard::TensorShapeProto* tensor_shape) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(tensor_shape_); + } + tensor_shape_ = tensor_shape; + if (tensor_shape) { + + } else { + + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:tensorboard.TensorProto.tensor_shape) +} +inline ::tensorboard::TensorShapeProto* TensorProto::release_tensor_shape() { + + ::tensorboard::TensorShapeProto* temp = tensor_shape_; + tensor_shape_ = nullptr; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::tensorboard::TensorShapeProto* TensorProto::unsafe_arena_release_tensor_shape() { + // @@protoc_insertion_point(field_release:tensorboard.TensorProto.tensor_shape) + + ::tensorboard::TensorShapeProto* temp = tensor_shape_; + tensor_shape_ = nullptr; + return temp; +} +inline ::tensorboard::TensorShapeProto* TensorProto::_internal_mutable_tensor_shape() { + + if (tensor_shape_ == nullptr) { + auto* p = CreateMaybeMessage<::tensorboard::TensorShapeProto>(GetArena()); + tensor_shape_ = p; + } + return tensor_shape_; +} +inline ::tensorboard::TensorShapeProto* TensorProto::mutable_tensor_shape() { + // @@protoc_insertion_point(field_mutable:tensorboard.TensorProto.tensor_shape) + return _internal_mutable_tensor_shape(); +} +inline void TensorProto::set_allocated_tensor_shape(::tensorboard::TensorShapeProto* tensor_shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete reinterpret_cast< ::PROTOBUF_NAMESPACE_ID::MessageLite*>(tensor_shape_); + } + if (tensor_shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(tensor_shape)->GetArena(); + if (message_arena != submessage_arena) { + tensor_shape = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, tensor_shape, submessage_arena); + } + + } else { + + } + tensor_shape_ = tensor_shape; + // @@protoc_insertion_point(field_set_allocated:tensorboard.TensorProto.tensor_shape) +} + +// int32 version_number = 3; +inline void TensorProto::clear_version_number() { + version_number_ = 0; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::_internal_version_number() const { + return version_number_; +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::version_number() const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.version_number) + return _internal_version_number(); +} +inline void TensorProto::_internal_set_version_number(::PROTOBUF_NAMESPACE_ID::int32 value) { + + version_number_ = value; +} +inline void TensorProto::set_version_number(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_set_version_number(value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.version_number) +} + +// bytes tensor_content = 4; +inline void TensorProto::clear_tensor_content() { + tensor_content_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& TensorProto::tensor_content() const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.tensor_content) + return _internal_tensor_content(); +} +inline void TensorProto::set_tensor_content(const std::string& value) { + _internal_set_tensor_content(value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.tensor_content) +} +inline std::string* TensorProto::mutable_tensor_content() { + // @@protoc_insertion_point(field_mutable:tensorboard.TensorProto.tensor_content) + return _internal_mutable_tensor_content(); +} +inline const std::string& TensorProto::_internal_tensor_content() const { + return tensor_content_.Get(); +} +inline void TensorProto::_internal_set_tensor_content(const std::string& value) { + + tensor_content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TensorProto::set_tensor_content(std::string&& value) { + + tensor_content_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.TensorProto.tensor_content) +} +inline void TensorProto::set_tensor_content(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + tensor_content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.TensorProto.tensor_content) +} +inline void TensorProto::set_tensor_content(const void* value, + size_t size) { + + tensor_content_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.TensorProto.tensor_content) +} +inline std::string* TensorProto::_internal_mutable_tensor_content() { + + return tensor_content_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TensorProto::release_tensor_content() { + // @@protoc_insertion_point(field_release:tensorboard.TensorProto.tensor_content) + return tensor_content_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TensorProto::set_allocated_tensor_content(std::string* tensor_content) { + if (tensor_content != nullptr) { + + } else { + + } + tensor_content_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), tensor_content, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.TensorProto.tensor_content) +} + +// repeated int32 half_val = 13 [packed = true]; +inline int TensorProto::_internal_half_val_size() const { + return half_val_.size(); +} +inline int TensorProto::half_val_size() const { + return _internal_half_val_size(); +} +inline void TensorProto::clear_half_val() { + half_val_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::_internal_half_val(int index) const { + return half_val_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::half_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.half_val) + return _internal_half_val(index); +} +inline void TensorProto::set_half_val(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + half_val_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.half_val) +} +inline void TensorProto::_internal_add_half_val(::PROTOBUF_NAMESPACE_ID::int32 value) { + half_val_.Add(value); +} +inline void TensorProto::add_half_val(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_half_val(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.half_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +TensorProto::_internal_half_val() const { + return half_val_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +TensorProto::half_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.half_val) + return _internal_half_val(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +TensorProto::_internal_mutable_half_val() { + return &half_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +TensorProto::mutable_half_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.half_val) + return _internal_mutable_half_val(); +} + +// repeated float float_val = 5 [packed = true]; +inline int TensorProto::_internal_float_val_size() const { + return float_val_.size(); +} +inline int TensorProto::float_val_size() const { + return _internal_float_val_size(); +} +inline void TensorProto::clear_float_val() { + float_val_.Clear(); +} +inline float TensorProto::_internal_float_val(int index) const { + return float_val_.Get(index); +} +inline float TensorProto::float_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.float_val) + return _internal_float_val(index); +} +inline void TensorProto::set_float_val(int index, float value) { + float_val_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.float_val) +} +inline void TensorProto::_internal_add_float_val(float value) { + float_val_.Add(value); +} +inline void TensorProto::add_float_val(float value) { + _internal_add_float_val(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.float_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +TensorProto::_internal_float_val() const { + return float_val_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +TensorProto::float_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.float_val) + return _internal_float_val(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +TensorProto::_internal_mutable_float_val() { + return &float_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +TensorProto::mutable_float_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.float_val) + return _internal_mutable_float_val(); +} + +// repeated double double_val = 6 [packed = true]; +inline int TensorProto::_internal_double_val_size() const { + return double_val_.size(); +} +inline int TensorProto::double_val_size() const { + return _internal_double_val_size(); +} +inline void TensorProto::clear_double_val() { + double_val_.Clear(); +} +inline double TensorProto::_internal_double_val(int index) const { + return double_val_.Get(index); +} +inline double TensorProto::double_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.double_val) + return _internal_double_val(index); +} +inline void TensorProto::set_double_val(int index, double value) { + double_val_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.double_val) +} +inline void TensorProto::_internal_add_double_val(double value) { + double_val_.Add(value); +} +inline void TensorProto::add_double_val(double value) { + _internal_add_double_val(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.double_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +TensorProto::_internal_double_val() const { + return double_val_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +TensorProto::double_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.double_val) + return _internal_double_val(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +TensorProto::_internal_mutable_double_val() { + return &double_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +TensorProto::mutable_double_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.double_val) + return _internal_mutable_double_val(); +} + +// repeated int32 int_val = 7 [packed = true]; +inline int TensorProto::_internal_int_val_size() const { + return int_val_.size(); +} +inline int TensorProto::int_val_size() const { + return _internal_int_val_size(); +} +inline void TensorProto::clear_int_val() { + int_val_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::_internal_int_val(int index) const { + return int_val_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::int_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.int_val) + return _internal_int_val(index); +} +inline void TensorProto::set_int_val(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + int_val_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.int_val) +} +inline void TensorProto::_internal_add_int_val(::PROTOBUF_NAMESPACE_ID::int32 value) { + int_val_.Add(value); +} +inline void TensorProto::add_int_val(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_int_val(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.int_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +TensorProto::_internal_int_val() const { + return int_val_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +TensorProto::int_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.int_val) + return _internal_int_val(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +TensorProto::_internal_mutable_int_val() { + return &int_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +TensorProto::mutable_int_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.int_val) + return _internal_mutable_int_val(); +} + +// repeated bytes string_val = 8; +inline int TensorProto::_internal_string_val_size() const { + return string_val_.size(); +} +inline int TensorProto::string_val_size() const { + return _internal_string_val_size(); +} +inline void TensorProto::clear_string_val() { + string_val_.Clear(); +} +inline std::string* TensorProto::add_string_val() { + // @@protoc_insertion_point(field_add_mutable:tensorboard.TensorProto.string_val) + return _internal_add_string_val(); +} +inline const std::string& TensorProto::_internal_string_val(int index) const { + return string_val_.Get(index); +} +inline const std::string& TensorProto::string_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.string_val) + return _internal_string_val(index); +} +inline std::string* TensorProto::mutable_string_val(int index) { + // @@protoc_insertion_point(field_mutable:tensorboard.TensorProto.string_val) + return string_val_.Mutable(index); +} +inline void TensorProto::set_string_val(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.string_val) + string_val_.Mutable(index)->assign(value); +} +inline void TensorProto::set_string_val(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.string_val) + string_val_.Mutable(index)->assign(std::move(value)); +} +inline void TensorProto::set_string_val(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + string_val_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:tensorboard.TensorProto.string_val) +} +inline void TensorProto::set_string_val(int index, const void* value, size_t size) { + string_val_.Mutable(index)->assign( + reinterpret_cast(value), size); + // @@protoc_insertion_point(field_set_pointer:tensorboard.TensorProto.string_val) +} +inline std::string* TensorProto::_internal_add_string_val() { + return string_val_.Add(); +} +inline void TensorProto::add_string_val(const std::string& value) { + string_val_.Add()->assign(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.string_val) +} +inline void TensorProto::add_string_val(std::string&& value) { + string_val_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.string_val) +} +inline void TensorProto::add_string_val(const char* value) { + GOOGLE_DCHECK(value != nullptr); + string_val_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:tensorboard.TensorProto.string_val) +} +inline void TensorProto::add_string_val(const void* value, size_t size) { + string_val_.Add()->assign(reinterpret_cast(value), size); + // @@protoc_insertion_point(field_add_pointer:tensorboard.TensorProto.string_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& +TensorProto::string_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.string_val) + return string_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* +TensorProto::mutable_string_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.string_val) + return &string_val_; +} + +// repeated float scomplex_val = 9 [packed = true]; +inline int TensorProto::_internal_scomplex_val_size() const { + return scomplex_val_.size(); +} +inline int TensorProto::scomplex_val_size() const { + return _internal_scomplex_val_size(); +} +inline void TensorProto::clear_scomplex_val() { + scomplex_val_.Clear(); +} +inline float TensorProto::_internal_scomplex_val(int index) const { + return scomplex_val_.Get(index); +} +inline float TensorProto::scomplex_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.scomplex_val) + return _internal_scomplex_val(index); +} +inline void TensorProto::set_scomplex_val(int index, float value) { + scomplex_val_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.scomplex_val) +} +inline void TensorProto::_internal_add_scomplex_val(float value) { + scomplex_val_.Add(value); +} +inline void TensorProto::add_scomplex_val(float value) { + _internal_add_scomplex_val(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.scomplex_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +TensorProto::_internal_scomplex_val() const { + return scomplex_val_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +TensorProto::scomplex_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.scomplex_val) + return _internal_scomplex_val(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +TensorProto::_internal_mutable_scomplex_val() { + return &scomplex_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +TensorProto::mutable_scomplex_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.scomplex_val) + return _internal_mutable_scomplex_val(); +} + +// repeated int64 int64_val = 10 [packed = true]; +inline int TensorProto::_internal_int64_val_size() const { + return int64_val_.size(); +} +inline int TensorProto::int64_val_size() const { + return _internal_int64_val_size(); +} +inline void TensorProto::clear_int64_val() { + int64_val_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::_internal_int64_val(int index) const { + return int64_val_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::int64_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.int64_val) + return _internal_int64_val(index); +} +inline void TensorProto::set_int64_val(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + int64_val_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.int64_val) +} +inline void TensorProto::_internal_add_int64_val(::PROTOBUF_NAMESPACE_ID::int64 value) { + int64_val_.Add(value); +} +inline void TensorProto::add_int64_val(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_int64_val(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.int64_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::_internal_int64_val() const { + return int64_val_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::int64_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.int64_val) + return _internal_int64_val(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::_internal_mutable_int64_val() { + return &int64_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::mutable_int64_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.int64_val) + return _internal_mutable_int64_val(); +} + +// repeated bool bool_val = 11 [packed = true]; +inline int TensorProto::_internal_bool_val_size() const { + return bool_val_.size(); +} +inline int TensorProto::bool_val_size() const { + return _internal_bool_val_size(); +} +inline void TensorProto::clear_bool_val() { + bool_val_.Clear(); +} +inline bool TensorProto::_internal_bool_val(int index) const { + return bool_val_.Get(index); +} +inline bool TensorProto::bool_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.bool_val) + return _internal_bool_val(index); +} +inline void TensorProto::set_bool_val(int index, bool value) { + bool_val_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.bool_val) +} +inline void TensorProto::_internal_add_bool_val(bool value) { + bool_val_.Add(value); +} +inline void TensorProto::add_bool_val(bool value) { + _internal_add_bool_val(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.bool_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >& +TensorProto::_internal_bool_val() const { + return bool_val_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >& +TensorProto::bool_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.bool_val) + return _internal_bool_val(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >* +TensorProto::_internal_mutable_bool_val() { + return &bool_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< bool >* +TensorProto::mutable_bool_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.bool_val) + return _internal_mutable_bool_val(); +} + +// repeated double dcomplex_val = 12 [packed = true]; +inline int TensorProto::_internal_dcomplex_val_size() const { + return dcomplex_val_.size(); +} +inline int TensorProto::dcomplex_val_size() const { + return _internal_dcomplex_val_size(); +} +inline void TensorProto::clear_dcomplex_val() { + dcomplex_val_.Clear(); +} +inline double TensorProto::_internal_dcomplex_val(int index) const { + return dcomplex_val_.Get(index); +} +inline double TensorProto::dcomplex_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.dcomplex_val) + return _internal_dcomplex_val(index); +} +inline void TensorProto::set_dcomplex_val(int index, double value) { + dcomplex_val_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.dcomplex_val) +} +inline void TensorProto::_internal_add_dcomplex_val(double value) { + dcomplex_val_.Add(value); +} +inline void TensorProto::add_dcomplex_val(double value) { + _internal_add_dcomplex_val(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.dcomplex_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +TensorProto::_internal_dcomplex_val() const { + return dcomplex_val_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +TensorProto::dcomplex_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.dcomplex_val) + return _internal_dcomplex_val(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +TensorProto::_internal_mutable_dcomplex_val() { + return &dcomplex_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +TensorProto::mutable_dcomplex_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.dcomplex_val) + return _internal_mutable_dcomplex_val(); +} + +// repeated .tensorboard.ResourceHandleProto resource_handle_val = 14; +inline int TensorProto::_internal_resource_handle_val_size() const { + return resource_handle_val_.size(); +} +inline int TensorProto::resource_handle_val_size() const { + return _internal_resource_handle_val_size(); +} +inline ::tensorboard::ResourceHandleProto* TensorProto::mutable_resource_handle_val(int index) { + // @@protoc_insertion_point(field_mutable:tensorboard.TensorProto.resource_handle_val) + return resource_handle_val_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::ResourceHandleProto >* +TensorProto::mutable_resource_handle_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.resource_handle_val) + return &resource_handle_val_; +} +inline const ::tensorboard::ResourceHandleProto& TensorProto::_internal_resource_handle_val(int index) const { + return resource_handle_val_.Get(index); +} +inline const ::tensorboard::ResourceHandleProto& TensorProto::resource_handle_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.resource_handle_val) + return _internal_resource_handle_val(index); +} +inline ::tensorboard::ResourceHandleProto* TensorProto::_internal_add_resource_handle_val() { + return resource_handle_val_.Add(); +} +inline ::tensorboard::ResourceHandleProto* TensorProto::add_resource_handle_val() { + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.resource_handle_val) + return _internal_add_resource_handle_val(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::ResourceHandleProto >& +TensorProto::resource_handle_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.resource_handle_val) + return resource_handle_val_; +} + +// repeated .tensorboard.VariantTensorDataProto variant_val = 15; +inline int TensorProto::_internal_variant_val_size() const { + return variant_val_.size(); +} +inline int TensorProto::variant_val_size() const { + return _internal_variant_val_size(); +} +inline void TensorProto::clear_variant_val() { + variant_val_.Clear(); +} +inline ::tensorboard::VariantTensorDataProto* TensorProto::mutable_variant_val(int index) { + // @@protoc_insertion_point(field_mutable:tensorboard.TensorProto.variant_val) + return variant_val_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::VariantTensorDataProto >* +TensorProto::mutable_variant_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.variant_val) + return &variant_val_; +} +inline const ::tensorboard::VariantTensorDataProto& TensorProto::_internal_variant_val(int index) const { + return variant_val_.Get(index); +} +inline const ::tensorboard::VariantTensorDataProto& TensorProto::variant_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.variant_val) + return _internal_variant_val(index); +} +inline ::tensorboard::VariantTensorDataProto* TensorProto::_internal_add_variant_val() { + return variant_val_.Add(); +} +inline ::tensorboard::VariantTensorDataProto* TensorProto::add_variant_val() { + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.variant_val) + return _internal_add_variant_val(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::VariantTensorDataProto >& +TensorProto::variant_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.variant_val) + return variant_val_; +} + +// repeated uint32 uint32_val = 16 [packed = true]; +inline int TensorProto::_internal_uint32_val_size() const { + return uint32_val_.size(); +} +inline int TensorProto::uint32_val_size() const { + return _internal_uint32_val_size(); +} +inline void TensorProto::clear_uint32_val() { + uint32_val_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::uint32 TensorProto::_internal_uint32_val(int index) const { + return uint32_val_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::uint32 TensorProto::uint32_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.uint32_val) + return _internal_uint32_val(index); +} +inline void TensorProto::set_uint32_val(int index, ::PROTOBUF_NAMESPACE_ID::uint32 value) { + uint32_val_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.uint32_val) +} +inline void TensorProto::_internal_add_uint32_val(::PROTOBUF_NAMESPACE_ID::uint32 value) { + uint32_val_.Add(value); +} +inline void TensorProto::add_uint32_val(::PROTOBUF_NAMESPACE_ID::uint32 value) { + _internal_add_uint32_val(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.uint32_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >& +TensorProto::_internal_uint32_val() const { + return uint32_val_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >& +TensorProto::uint32_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.uint32_val) + return _internal_uint32_val(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >* +TensorProto::_internal_mutable_uint32_val() { + return &uint32_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint32 >* +TensorProto::mutable_uint32_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.uint32_val) + return _internal_mutable_uint32_val(); +} + +// repeated uint64 uint64_val = 17 [packed = true]; +inline int TensorProto::_internal_uint64_val_size() const { + return uint64_val_.size(); +} +inline int TensorProto::uint64_val_size() const { + return _internal_uint64_val_size(); +} +inline void TensorProto::clear_uint64_val() { + uint64_val_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 TensorProto::_internal_uint64_val(int index) const { + return uint64_val_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 TensorProto::uint64_val(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.uint64_val) + return _internal_uint64_val(index); +} +inline void TensorProto::set_uint64_val(int index, ::PROTOBUF_NAMESPACE_ID::uint64 value) { + uint64_val_.Set(index, value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.uint64_val) +} +inline void TensorProto::_internal_add_uint64_val(::PROTOBUF_NAMESPACE_ID::uint64 value) { + uint64_val_.Add(value); +} +inline void TensorProto::add_uint64_val(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_add_uint64_val(value); + // @@protoc_insertion_point(field_add:tensorboard.TensorProto.uint64_val) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& +TensorProto::_internal_uint64_val() const { + return uint64_val_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& +TensorProto::uint64_val() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorProto.uint64_val) + return _internal_uint64_val(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* +TensorProto::_internal_mutable_uint64_val() { + return &uint64_val_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* +TensorProto::mutable_uint64_val() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorProto.uint64_val) + return _internal_mutable_uint64_val(); +} + +// bytes float8_val = 18; +inline void TensorProto::clear_float8_val() { + float8_val_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& TensorProto::float8_val() const { + // @@protoc_insertion_point(field_get:tensorboard.TensorProto.float8_val) + return _internal_float8_val(); +} +inline void TensorProto::set_float8_val(const std::string& value) { + _internal_set_float8_val(value); + // @@protoc_insertion_point(field_set:tensorboard.TensorProto.float8_val) +} +inline std::string* TensorProto::mutable_float8_val() { + // @@protoc_insertion_point(field_mutable:tensorboard.TensorProto.float8_val) + return _internal_mutable_float8_val(); +} +inline const std::string& TensorProto::_internal_float8_val() const { + return float8_val_.Get(); +} +inline void TensorProto::_internal_set_float8_val(const std::string& value) { + + float8_val_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TensorProto::set_float8_val(std::string&& value) { + + float8_val_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.TensorProto.float8_val) +} +inline void TensorProto::set_float8_val(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + float8_val_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.TensorProto.float8_val) +} +inline void TensorProto::set_float8_val(const void* value, + size_t size) { + + float8_val_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.TensorProto.float8_val) +} +inline std::string* TensorProto::_internal_mutable_float8_val() { + + return float8_val_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TensorProto::release_float8_val() { + // @@protoc_insertion_point(field_release:tensorboard.TensorProto.float8_val) + return float8_val_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TensorProto::set_allocated_float8_val(std::string* float8_val) { + if (float8_val != nullptr) { + + } else { + + } + float8_val_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), float8_val, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.TensorProto.float8_val) +} + +// ------------------------------------------------------------------- + +// VariantTensorDataProto + +// string type_name = 1; +inline void VariantTensorDataProto::clear_type_name() { + type_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& VariantTensorDataProto::type_name() const { + // @@protoc_insertion_point(field_get:tensorboard.VariantTensorDataProto.type_name) + return _internal_type_name(); +} +inline void VariantTensorDataProto::set_type_name(const std::string& value) { + _internal_set_type_name(value); + // @@protoc_insertion_point(field_set:tensorboard.VariantTensorDataProto.type_name) +} +inline std::string* VariantTensorDataProto::mutable_type_name() { + // @@protoc_insertion_point(field_mutable:tensorboard.VariantTensorDataProto.type_name) + return _internal_mutable_type_name(); +} +inline const std::string& VariantTensorDataProto::_internal_type_name() const { + return type_name_.Get(); +} +inline void VariantTensorDataProto::_internal_set_type_name(const std::string& value) { + + type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void VariantTensorDataProto::set_type_name(std::string&& value) { + + type_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.VariantTensorDataProto.type_name) +} +inline void VariantTensorDataProto::set_type_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.VariantTensorDataProto.type_name) +} +inline void VariantTensorDataProto::set_type_name(const char* value, + size_t size) { + + type_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.VariantTensorDataProto.type_name) +} +inline std::string* VariantTensorDataProto::_internal_mutable_type_name() { + + return type_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* VariantTensorDataProto::release_type_name() { + // @@protoc_insertion_point(field_release:tensorboard.VariantTensorDataProto.type_name) + return type_name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void VariantTensorDataProto::set_allocated_type_name(std::string* type_name) { + if (type_name != nullptr) { + + } else { + + } + type_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), type_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.VariantTensorDataProto.type_name) +} + +// bytes metadata = 2; +inline void VariantTensorDataProto::clear_metadata() { + metadata_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& VariantTensorDataProto::metadata() const { + // @@protoc_insertion_point(field_get:tensorboard.VariantTensorDataProto.metadata) + return _internal_metadata(); +} +inline void VariantTensorDataProto::set_metadata(const std::string& value) { + _internal_set_metadata(value); + // @@protoc_insertion_point(field_set:tensorboard.VariantTensorDataProto.metadata) +} +inline std::string* VariantTensorDataProto::mutable_metadata() { + // @@protoc_insertion_point(field_mutable:tensorboard.VariantTensorDataProto.metadata) + return _internal_mutable_metadata(); +} +inline const std::string& VariantTensorDataProto::_internal_metadata() const { + return metadata_.Get(); +} +inline void VariantTensorDataProto::_internal_set_metadata(const std::string& value) { + + metadata_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void VariantTensorDataProto::set_metadata(std::string&& value) { + + metadata_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.VariantTensorDataProto.metadata) +} +inline void VariantTensorDataProto::set_metadata(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + metadata_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.VariantTensorDataProto.metadata) +} +inline void VariantTensorDataProto::set_metadata(const void* value, + size_t size) { + + metadata_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.VariantTensorDataProto.metadata) +} +inline std::string* VariantTensorDataProto::_internal_mutable_metadata() { + + return metadata_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* VariantTensorDataProto::release_metadata() { + // @@protoc_insertion_point(field_release:tensorboard.VariantTensorDataProto.metadata) + return metadata_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void VariantTensorDataProto::set_allocated_metadata(std::string* metadata) { + if (metadata != nullptr) { + + } else { + + } + metadata_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), metadata, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.VariantTensorDataProto.metadata) +} + +// repeated .tensorboard.TensorProto tensors = 3; +inline int VariantTensorDataProto::_internal_tensors_size() const { + return tensors_.size(); +} +inline int VariantTensorDataProto::tensors_size() const { + return _internal_tensors_size(); +} +inline void VariantTensorDataProto::clear_tensors() { + tensors_.Clear(); +} +inline ::tensorboard::TensorProto* VariantTensorDataProto::mutable_tensors(int index) { + // @@protoc_insertion_point(field_mutable:tensorboard.VariantTensorDataProto.tensors) + return tensors_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::TensorProto >* +VariantTensorDataProto::mutable_tensors() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.VariantTensorDataProto.tensors) + return &tensors_; +} +inline const ::tensorboard::TensorProto& VariantTensorDataProto::_internal_tensors(int index) const { + return tensors_.Get(index); +} +inline const ::tensorboard::TensorProto& VariantTensorDataProto::tensors(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.VariantTensorDataProto.tensors) + return _internal_tensors(index); +} +inline ::tensorboard::TensorProto* VariantTensorDataProto::_internal_add_tensors() { + return tensors_.Add(); +} +inline ::tensorboard::TensorProto* VariantTensorDataProto::add_tensors() { + // @@protoc_insertion_point(field_add:tensorboard.VariantTensorDataProto.tensors) + return _internal_add_tensors(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::TensorProto >& +VariantTensorDataProto::tensors() const { + // @@protoc_insertion_point(field_list:tensorboard.VariantTensorDataProto.tensors) + return tensors_; +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +} // namespace tensorboard + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_tensor_2eproto diff --git a/plugins/mindstudio-insight-plugins/proto/tensor.proto b/plugins/mindstudio-insight-plugins/proto/tensor.proto new file mode 100644 index 0000000000000000000000000000000000000000..251ac2b6b7fa85969c193b657080e0440ac7a700 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/tensor.proto @@ -0,0 +1,101 @@ +syntax = "proto3"; + +package tensorboard; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_go_proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + // Data type of the tensor. + DataType dtype = 1; + + // Shape of the tensor. TODO(touts): sort out the 0-rank issues. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_UINT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; + + // DT_FLOAT8_*, use variable-sized set of bytes + // (i.e. the equivalent of repeated uint8, if such a thing existed). + bytes float8_val = 18; +} + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/plugins/mindstudio-insight-plugins/proto/tensor_shape.pb.cc b/plugins/mindstudio-insight-plugins/proto/tensor_shape.pb.cc new file mode 100644 index 0000000000000000000000000000000000000000..5d61428efe34a66ee138694e50aea09e7597126f --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/tensor_shape.pb.cc @@ -0,0 +1,591 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensor_shape.proto + +#include "tensor_shape.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +extern PROTOBUF_INTERNAL_EXPORT_tensor_5fshape_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorShapeProto_Dim_tensor_5fshape_2eproto; +namespace tensorboard { +class TensorShapeProto_DimDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TensorShapeProto_Dim_default_instance_; +class TensorShapeProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TensorShapeProto_default_instance_; +} // namespace tensorboard +static void InitDefaultsscc_info_TensorShapeProto_tensor_5fshape_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_TensorShapeProto_default_instance_; + new (ptr) ::tensorboard::TensorShapeProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::TensorShapeProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TensorShapeProto_tensor_5fshape_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_TensorShapeProto_tensor_5fshape_2eproto}, { + &scc_info_TensorShapeProto_Dim_tensor_5fshape_2eproto.base,}}; + +static void InitDefaultsscc_info_TensorShapeProto_Dim_tensor_5fshape_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_TensorShapeProto_Dim_default_instance_; + new (ptr) ::tensorboard::TensorShapeProto_Dim(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::TensorShapeProto_Dim::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorShapeProto_Dim_tensor_5fshape_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_TensorShapeProto_Dim_tensor_5fshape_2eproto}, {}}; + +static ::PROTOBUF_NAMESPACE_ID::Metadata file_level_metadata_tensor_5fshape_2eproto[2]; +static constexpr ::PROTOBUF_NAMESPACE_ID::EnumDescriptor const** file_level_enum_descriptors_tensor_5fshape_2eproto = nullptr; +static constexpr ::PROTOBUF_NAMESPACE_ID::ServiceDescriptor const** file_level_service_descriptors_tensor_5fshape_2eproto = nullptr; + +const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_tensor_5fshape_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorShapeProto_Dim, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorShapeProto_Dim, size_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorShapeProto_Dim, name_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorShapeProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorShapeProto, dim_), + PROTOBUF_FIELD_OFFSET(::tensorboard::TensorShapeProto, unknown_rank_), +}; +static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, -1, sizeof(::tensorboard::TensorShapeProto_Dim)}, + { 7, -1, sizeof(::tensorboard::TensorShapeProto)}, +}; + +static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = { + reinterpret_cast(&::tensorboard::_TensorShapeProto_Dim_default_instance_), + reinterpret_cast(&::tensorboard::_TensorShapeProto_default_instance_), +}; + +const char descriptor_table_protodef_tensor_5fshape_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = + "\n\022tensor_shape.proto\022\013tensorboard\"{\n\020Ten" + "sorShapeProto\022.\n\003dim\030\002 \003(\0132!.tensorboard" + ".TensorShapeProto.Dim\022\024\n\014unknown_rank\030\003 " + "\001(\010\032!\n\003Dim\022\014\n\004size\030\001 \001(\003\022\014\n\004name\030\002 \001(\tB\207" + "\001\n\030org.tensorflow.frameworkB\021TensorShape" + "ProtosP\001ZSgithub.com/tensorflow/tensorfl" + "ow/tensorflow/go/core/framework/tensor_s" + "hape_go_proto\370\001\001b\006proto3" + ; +static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_tensor_5fshape_2eproto_deps[1] = { +}; +static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_tensor_5fshape_2eproto_sccs[2] = { + &scc_info_TensorShapeProto_tensor_5fshape_2eproto.base, + &scc_info_TensorShapeProto_Dim_tensor_5fshape_2eproto.base, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_tensor_5fshape_2eproto_once; +const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_tensor_5fshape_2eproto = { + false, false, descriptor_table_protodef_tensor_5fshape_2eproto, "tensor_shape.proto", 304, + &descriptor_table_tensor_5fshape_2eproto_once, descriptor_table_tensor_5fshape_2eproto_sccs, descriptor_table_tensor_5fshape_2eproto_deps, 2, 0, + schemas, file_default_instances, TableStruct_tensor_5fshape_2eproto::offsets, + file_level_metadata_tensor_5fshape_2eproto, 2, file_level_enum_descriptors_tensor_5fshape_2eproto, file_level_service_descriptors_tensor_5fshape_2eproto, +}; + +// Force running AddDescriptors() at dynamic initialization time. +static bool dynamic_init_dummy_tensor_5fshape_2eproto = (static_cast(::PROTOBUF_NAMESPACE_ID::internal::AddDescriptors(&descriptor_table_tensor_5fshape_2eproto)), true); +namespace tensorboard { + +// =================================================================== + +void TensorShapeProto_Dim::InitAsDefaultInstance() { +} +class TensorShapeProto_Dim::_Internal { + public: +}; + +TensorShapeProto_Dim::TensorShapeProto_Dim(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.TensorShapeProto.Dim) +} +TensorShapeProto_Dim::TensorShapeProto_Dim(const TensorShapeProto_Dim& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (!from._internal_name().empty()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + size_ = from.size_; + // @@protoc_insertion_point(copy_constructor:tensorboard.TensorShapeProto.Dim) +} + +void TensorShapeProto_Dim::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TensorShapeProto_Dim_tensor_5fshape_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + size_ = PROTOBUF_LONGLONG(0); +} + +TensorShapeProto_Dim::~TensorShapeProto_Dim() { + // @@protoc_insertion_point(destructor:tensorboard.TensorShapeProto.Dim) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TensorShapeProto_Dim::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void TensorShapeProto_Dim::ArenaDtor(void* object) { + TensorShapeProto_Dim* _this = reinterpret_cast< TensorShapeProto_Dim* >(object); + (void)_this; +} +void TensorShapeProto_Dim::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TensorShapeProto_Dim::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TensorShapeProto_Dim& TensorShapeProto_Dim::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorShapeProto_Dim_tensor_5fshape_2eproto.base); + return *internal_default_instance(); +} + + +void TensorShapeProto_Dim::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.TensorShapeProto.Dim) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + size_ = PROTOBUF_LONGLONG(0); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TensorShapeProto_Dim::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // int64 size = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + size_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string name = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "tensorboard.TensorShapeProto.Dim.name")); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TensorShapeProto_Dim::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.TensorShapeProto.Dim) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // int64 size = 1; + if (this->size() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(1, this->_internal_size(), target); + } + + // string name = 2; + if (this->name().size() > 0) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "tensorboard.TensorShapeProto.Dim.name"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_name(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.TensorShapeProto.Dim) + return target; +} + +size_t TensorShapeProto_Dim::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.TensorShapeProto.Dim) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // string name = 2; + if (this->name().size() > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // int64 size = 1; + if (this->size() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_size()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TensorShapeProto_Dim::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.TensorShapeProto.Dim) + GOOGLE_DCHECK_NE(&from, this); + const TensorShapeProto_Dim* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.TensorShapeProto.Dim) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.TensorShapeProto.Dim) + MergeFrom(*source); + } +} + +void TensorShapeProto_Dim::MergeFrom(const TensorShapeProto_Dim& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.TensorShapeProto.Dim) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.name().size() > 0) { + _internal_set_name(from._internal_name()); + } + if (from.size() != 0) { + _internal_set_size(from._internal_size()); + } +} + +void TensorShapeProto_Dim::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.TensorShapeProto.Dim) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TensorShapeProto_Dim::CopyFrom(const TensorShapeProto_Dim& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.TensorShapeProto.Dim) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TensorShapeProto_Dim::IsInitialized() const { + return true; +} + +void TensorShapeProto_Dim::InternalSwap(TensorShapeProto_Dim* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(size_, other->size_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TensorShapeProto_Dim::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TensorShapeProto::InitAsDefaultInstance() { +} +class TensorShapeProto::_Internal { + public: +}; + +TensorShapeProto::TensorShapeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + dim_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.TensorShapeProto) +} +TensorShapeProto::TensorShapeProto(const TensorShapeProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + dim_(from.dim_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + unknown_rank_ = from.unknown_rank_; + // @@protoc_insertion_point(copy_constructor:tensorboard.TensorShapeProto) +} + +void TensorShapeProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TensorShapeProto_tensor_5fshape_2eproto.base); + unknown_rank_ = false; +} + +TensorShapeProto::~TensorShapeProto() { + // @@protoc_insertion_point(destructor:tensorboard.TensorShapeProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TensorShapeProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void TensorShapeProto::ArenaDtor(void* object) { + TensorShapeProto* _this = reinterpret_cast< TensorShapeProto* >(object); + (void)_this; +} +void TensorShapeProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TensorShapeProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TensorShapeProto& TensorShapeProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorShapeProto_tensor_5fshape_2eproto.base); + return *internal_default_instance(); +} + + +void TensorShapeProto::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.TensorShapeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + dim_.Clear(); + unknown_rank_ = false; + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TensorShapeProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .tensorboard.TensorShapeProto.Dim dim = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_dim(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<18>(ptr)); + } else goto handle_unusual; + continue; + // bool unknown_rank = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + unknown_rank_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TensorShapeProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.TensorShapeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .tensorboard.TensorShapeProto.Dim dim = 2; + for (unsigned int i = 0, + n = static_cast(this->_internal_dim_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(2, this->_internal_dim(i), target, stream); + } + + // bool unknown_rank = 3; + if (this->unknown_rank() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteBoolToArray(3, this->_internal_unknown_rank(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.TensorShapeProto) + return target; +} + +size_t TensorShapeProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.TensorShapeProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .tensorboard.TensorShapeProto.Dim dim = 2; + total_size += 1UL * this->_internal_dim_size(); + for (const auto& msg : this->dim_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // bool unknown_rank = 3; + if (this->unknown_rank() != 0) { + total_size += 1 + 1; + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TensorShapeProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.TensorShapeProto) + GOOGLE_DCHECK_NE(&from, this); + const TensorShapeProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.TensorShapeProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.TensorShapeProto) + MergeFrom(*source); + } +} + +void TensorShapeProto::MergeFrom(const TensorShapeProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.TensorShapeProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + dim_.MergeFrom(from.dim_); + if (from.unknown_rank() != 0) { + _internal_set_unknown_rank(from._internal_unknown_rank()); + } +} + +void TensorShapeProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.TensorShapeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TensorShapeProto::CopyFrom(const TensorShapeProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.TensorShapeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TensorShapeProto::IsInitialized() const { + return true; +} + +void TensorShapeProto::InternalSwap(TensorShapeProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + dim_.InternalSwap(&other->dim_); + swap(unknown_rank_, other->unknown_rank_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TensorShapeProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_NOINLINE ::tensorboard::TensorShapeProto_Dim* Arena::CreateMaybeMessage< ::tensorboard::TensorShapeProto_Dim >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::TensorShapeProto_Dim >(arena); +} +template<> PROTOBUF_NOINLINE ::tensorboard::TensorShapeProto* Arena::CreateMaybeMessage< ::tensorboard::TensorShapeProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::TensorShapeProto >(arena); +} +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) +#include diff --git a/plugins/mindstudio-insight-plugins/proto/tensor_shape.pb.h b/plugins/mindstudio-insight-plugins/proto/tensor_shape.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..3d971e0d9156979dff3418062ee0469a325cd670 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/tensor_shape.pb.h @@ -0,0 +1,554 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensor_shape.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_tensor_5fshape_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_tensor_5fshape_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_tensor_5fshape_2eproto +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct TableStruct_tensor_5fshape_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[2] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_tensor_5fshape_2eproto; +namespace tensorboard { +class TensorShapeProto; +class TensorShapeProtoDefaultTypeInternal; +extern TensorShapeProtoDefaultTypeInternal _TensorShapeProto_default_instance_; +class TensorShapeProto_Dim; +class TensorShapeProto_DimDefaultTypeInternal; +extern TensorShapeProto_DimDefaultTypeInternal _TensorShapeProto_Dim_default_instance_; +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> ::tensorboard::TensorShapeProto* Arena::CreateMaybeMessage<::tensorboard::TensorShapeProto>(Arena*); +template<> ::tensorboard::TensorShapeProto_Dim* Arena::CreateMaybeMessage<::tensorboard::TensorShapeProto_Dim>(Arena*); +PROTOBUF_NAMESPACE_CLOSE +namespace tensorboard { + +// =================================================================== + +class TensorShapeProto_Dim PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.TensorShapeProto.Dim) */ { + public: + inline TensorShapeProto_Dim() : TensorShapeProto_Dim(nullptr) {} + virtual ~TensorShapeProto_Dim(); + + TensorShapeProto_Dim(const TensorShapeProto_Dim& from); + TensorShapeProto_Dim(TensorShapeProto_Dim&& from) noexcept + : TensorShapeProto_Dim() { + *this = ::std::move(from); + } + + inline TensorShapeProto_Dim& operator=(const TensorShapeProto_Dim& from) { + CopyFrom(from); + return *this; + } + inline TensorShapeProto_Dim& operator=(TensorShapeProto_Dim&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorShapeProto_Dim& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorShapeProto_Dim* internal_default_instance() { + return reinterpret_cast( + &_TensorShapeProto_Dim_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(TensorShapeProto_Dim& a, TensorShapeProto_Dim& b) { + a.Swap(&b); + } + inline void Swap(TensorShapeProto_Dim* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TensorShapeProto_Dim* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorShapeProto_Dim* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorShapeProto_Dim* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorShapeProto_Dim& from); + void MergeFrom(const TensorShapeProto_Dim& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorShapeProto_Dim* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.TensorShapeProto.Dim"; + } + protected: + explicit TensorShapeProto_Dim(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_tensor_5fshape_2eproto); + return ::descriptor_table_tensor_5fshape_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 2, + kSizeFieldNumber = 1, + }; + // string name = 2; + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // int64 size = 1; + void clear_size(); + ::PROTOBUF_NAMESPACE_ID::int64 size() const; + void set_size(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_size() const; + void _internal_set_size(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.TensorShapeProto.Dim) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::int64 size_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_tensor_5fshape_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorShapeProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.TensorShapeProto) */ { + public: + inline TensorShapeProto() : TensorShapeProto(nullptr) {} + virtual ~TensorShapeProto(); + + TensorShapeProto(const TensorShapeProto& from); + TensorShapeProto(TensorShapeProto&& from) noexcept + : TensorShapeProto() { + *this = ::std::move(from); + } + + inline TensorShapeProto& operator=(const TensorShapeProto& from) { + CopyFrom(from); + return *this; + } + inline TensorShapeProto& operator=(TensorShapeProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorShapeProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorShapeProto* internal_default_instance() { + return reinterpret_cast( + &_TensorShapeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(TensorShapeProto& a, TensorShapeProto& b) { + a.Swap(&b); + } + inline void Swap(TensorShapeProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TensorShapeProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorShapeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorShapeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorShapeProto& from); + void MergeFrom(const TensorShapeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorShapeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.TensorShapeProto"; + } + protected: + explicit TensorShapeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_tensor_5fshape_2eproto); + return ::descriptor_table_tensor_5fshape_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef TensorShapeProto_Dim Dim; + + // accessors ------------------------------------------------------- + + enum : int { + kDimFieldNumber = 2, + kUnknownRankFieldNumber = 3, + }; + // repeated .tensorboard.TensorShapeProto.Dim dim = 2; + int dim_size() const; + private: + int _internal_dim_size() const; + public: + void clear_dim(); + ::tensorboard::TensorShapeProto_Dim* mutable_dim(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::TensorShapeProto_Dim >* + mutable_dim(); + private: + const ::tensorboard::TensorShapeProto_Dim& _internal_dim(int index) const; + ::tensorboard::TensorShapeProto_Dim* _internal_add_dim(); + public: + const ::tensorboard::TensorShapeProto_Dim& dim(int index) const; + ::tensorboard::TensorShapeProto_Dim* add_dim(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::TensorShapeProto_Dim >& + dim() const; + + // bool unknown_rank = 3; + void clear_unknown_rank(); + bool unknown_rank() const; + void set_unknown_rank(bool value); + private: + bool _internal_unknown_rank() const; + void _internal_set_unknown_rank(bool value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.TensorShapeProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::TensorShapeProto_Dim > dim_; + bool unknown_rank_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_tensor_5fshape_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// TensorShapeProto_Dim + +// int64 size = 1; +inline void TensorShapeProto_Dim::clear_size() { + size_ = PROTOBUF_LONGLONG(0); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorShapeProto_Dim::_internal_size() const { + return size_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorShapeProto_Dim::size() const { + // @@protoc_insertion_point(field_get:tensorboard.TensorShapeProto.Dim.size) + return _internal_size(); +} +inline void TensorShapeProto_Dim::_internal_set_size(::PROTOBUF_NAMESPACE_ID::int64 value) { + + size_ = value; +} +inline void TensorShapeProto_Dim::set_size(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_size(value); + // @@protoc_insertion_point(field_set:tensorboard.TensorShapeProto.Dim.size) +} + +// string name = 2; +inline void TensorShapeProto_Dim::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline const std::string& TensorShapeProto_Dim::name() const { + // @@protoc_insertion_point(field_get:tensorboard.TensorShapeProto.Dim.name) + return _internal_name(); +} +inline void TensorShapeProto_Dim::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:tensorboard.TensorShapeProto.Dim.name) +} +inline std::string* TensorShapeProto_Dim::mutable_name() { + // @@protoc_insertion_point(field_mutable:tensorboard.TensorShapeProto.Dim.name) + return _internal_mutable_name(); +} +inline const std::string& TensorShapeProto_Dim::_internal_name() const { + return name_.Get(); +} +inline void TensorShapeProto_Dim::_internal_set_name(const std::string& value) { + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TensorShapeProto_Dim::set_name(std::string&& value) { + + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:tensorboard.TensorShapeProto.Dim.name) +} +inline void TensorShapeProto_Dim::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:tensorboard.TensorShapeProto.Dim.name) +} +inline void TensorShapeProto_Dim::set_name(const char* value, + size_t size) { + + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:tensorboard.TensorShapeProto.Dim.name) +} +inline std::string* TensorShapeProto_Dim::_internal_mutable_name() { + + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TensorShapeProto_Dim::release_name() { + // @@protoc_insertion_point(field_release:tensorboard.TensorShapeProto.Dim.name) + return name_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TensorShapeProto_Dim::set_allocated_name(std::string* name) { + if (name != nullptr) { + + } else { + + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:tensorboard.TensorShapeProto.Dim.name) +} + +// ------------------------------------------------------------------- + +// TensorShapeProto + +// repeated .tensorboard.TensorShapeProto.Dim dim = 2; +inline int TensorShapeProto::_internal_dim_size() const { + return dim_.size(); +} +inline int TensorShapeProto::dim_size() const { + return _internal_dim_size(); +} +inline void TensorShapeProto::clear_dim() { + dim_.Clear(); +} +inline ::tensorboard::TensorShapeProto_Dim* TensorShapeProto::mutable_dim(int index) { + // @@protoc_insertion_point(field_mutable:tensorboard.TensorShapeProto.dim) + return dim_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::TensorShapeProto_Dim >* +TensorShapeProto::mutable_dim() { + // @@protoc_insertion_point(field_mutable_list:tensorboard.TensorShapeProto.dim) + return &dim_; +} +inline const ::tensorboard::TensorShapeProto_Dim& TensorShapeProto::_internal_dim(int index) const { + return dim_.Get(index); +} +inline const ::tensorboard::TensorShapeProto_Dim& TensorShapeProto::dim(int index) const { + // @@protoc_insertion_point(field_get:tensorboard.TensorShapeProto.dim) + return _internal_dim(index); +} +inline ::tensorboard::TensorShapeProto_Dim* TensorShapeProto::_internal_add_dim() { + return dim_.Add(); +} +inline ::tensorboard::TensorShapeProto_Dim* TensorShapeProto::add_dim() { + // @@protoc_insertion_point(field_add:tensorboard.TensorShapeProto.dim) + return _internal_add_dim(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::tensorboard::TensorShapeProto_Dim >& +TensorShapeProto::dim() const { + // @@protoc_insertion_point(field_list:tensorboard.TensorShapeProto.dim) + return dim_; +} + +// bool unknown_rank = 3; +inline void TensorShapeProto::clear_unknown_rank() { + unknown_rank_ = false; +} +inline bool TensorShapeProto::_internal_unknown_rank() const { + return unknown_rank_; +} +inline bool TensorShapeProto::unknown_rank() const { + // @@protoc_insertion_point(field_get:tensorboard.TensorShapeProto.unknown_rank) + return _internal_unknown_rank(); +} +inline void TensorShapeProto::_internal_set_unknown_rank(bool value) { + + unknown_rank_ = value; +} +inline void TensorShapeProto::set_unknown_rank(bool value) { + _internal_set_unknown_rank(value); + // @@protoc_insertion_point(field_set:tensorboard.TensorShapeProto.unknown_rank) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ +// ------------------------------------------------------------------- + + +// @@protoc_insertion_point(namespace_scope) + +} // namespace tensorboard + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_tensor_5fshape_2eproto diff --git a/plugins/mindstudio-insight-plugins/proto/tensor_shape.proto b/plugins/mindstudio-insight-plugins/proto/tensor_shape.proto new file mode 100644 index 0000000000000000000000000000000000000000..9da8a082b5129973676357c5ff7f82b11bd65c94 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/tensor_shape.proto @@ -0,0 +1,47 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; + +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"; + +package tensorboard; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/plugins/mindstudio-insight-plugins/proto/types.pb.cc b/plugins/mindstudio-insight-plugins/proto/types.pb.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d40a8cf81d7d9f81c7226cc9215cc62e67845bf --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/types.pb.cc @@ -0,0 +1,380 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: types.proto + +#include "types.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +namespace tensorboard { +class SerializedDTypeDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _SerializedDType_default_instance_; +} // namespace tensorboard +static void InitDefaultsscc_info_SerializedDType_types_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::tensorboard::_SerializedDType_default_instance_; + new (ptr) ::tensorboard::SerializedDType(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::tensorboard::SerializedDType::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_SerializedDType_types_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_SerializedDType_types_2eproto}, {}}; + +static ::PROTOBUF_NAMESPACE_ID::Metadata file_level_metadata_types_2eproto[1]; +static const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* file_level_enum_descriptors_types_2eproto[1]; +static constexpr ::PROTOBUF_NAMESPACE_ID::ServiceDescriptor const** file_level_service_descriptors_types_2eproto = nullptr; + +const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_types_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SerializedDType, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::tensorboard::SerializedDType, datatype_), +}; +static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, -1, sizeof(::tensorboard::SerializedDType)}, +}; + +static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = { + reinterpret_cast(&::tensorboard::_SerializedDType_default_instance_), +}; + +const char descriptor_table_protodef_types_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = + "\n\013types.proto\022\013tensorboard\":\n\017Serialized" + "DType\022\'\n\010datatype\030\001 \001(\0162\025.tensorboard.Da" + "taType*\306\007\n\010DataType\022\016\n\nDT_INVALID\020\000\022\014\n\010D" + "T_FLOAT\020\001\022\r\n\tDT_DOUBLE\020\002\022\014\n\010DT_INT32\020\003\022\014" + "\n\010DT_UINT8\020\004\022\014\n\010DT_INT16\020\005\022\013\n\007DT_INT8\020\006\022" + "\r\n\tDT_STRING\020\007\022\020\n\014DT_COMPLEX64\020\010\022\014\n\010DT_I" + "NT64\020\t\022\013\n\007DT_BOOL\020\n\022\014\n\010DT_QINT8\020\013\022\r\n\tDT_" + "QUINT8\020\014\022\r\n\tDT_QINT32\020\r\022\017\n\013DT_BFLOAT16\020\016" + "\022\r\n\tDT_QINT16\020\017\022\016\n\nDT_QUINT16\020\020\022\r\n\tDT_UI" + "NT16\020\021\022\021\n\rDT_COMPLEX128\020\022\022\013\n\007DT_HALF\020\023\022\017" + "\n\013DT_RESOURCE\020\024\022\016\n\nDT_VARIANT\020\025\022\r\n\tDT_UI" + "NT32\020\026\022\r\n\tDT_UINT64\020\027\022\022\n\016DT_FLOAT8_E5M2\020" + "\030\022\024\n\020DT_FLOAT8_E4M3FN\020\031\022\013\n\007DT_INT4\020\035\022\014\n\010" + "DT_UINT4\020\036\022\020\n\014DT_FLOAT_REF\020e\022\021\n\rDT_DOUBL" + "E_REF\020f\022\020\n\014DT_INT32_REF\020g\022\020\n\014DT_UINT8_RE" + "F\020h\022\020\n\014DT_INT16_REF\020i\022\017\n\013DT_INT8_REF\020j\022\021" + "\n\rDT_STRING_REF\020k\022\024\n\020DT_COMPLEX64_REF\020l\022" + "\020\n\014DT_INT64_REF\020m\022\017\n\013DT_BOOL_REF\020n\022\020\n\014DT" + "_QINT8_REF\020o\022\021\n\rDT_QUINT8_REF\020p\022\021\n\rDT_QI" + "NT32_REF\020q\022\023\n\017DT_BFLOAT16_REF\020r\022\021\n\rDT_QI" + "NT16_REF\020s\022\022\n\016DT_QUINT16_REF\020t\022\021\n\rDT_UIN" + "T16_REF\020u\022\025\n\021DT_COMPLEX128_REF\020v\022\017\n\013DT_H" + "ALF_REF\020w\022\023\n\017DT_RESOURCE_REF\020x\022\022\n\016DT_VAR" + "IANT_REF\020y\022\021\n\rDT_UINT32_REF\020z\022\021\n\rDT_UINT" + "64_REF\020{\022\026\n\022DT_FLOAT8_E5M2_REF\020|\022\030\n\024DT_F" + "LOAT8_E4M3FN_REF\020}\022\020\n\013DT_INT4_REF\020\201\001\022\021\n\014" + "DT_UINT4_REF\020\202\001Bz\n\030org.tensorflow.framew" + "orkB\013TypesProtosP\001ZLgithub.com/tensorflo" + "w/tensorflow/tensorflow/go/core/framewor" + "k/types_go_proto\370\001\001b\006proto3" + ; +static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_types_2eproto_deps[1] = { +}; +static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_types_2eproto_sccs[1] = { + &scc_info_SerializedDType_types_2eproto.base, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_types_2eproto_once; +const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_types_2eproto = { + false, false, descriptor_table_protodef_types_2eproto, "types.proto", 1187, + &descriptor_table_types_2eproto_once, descriptor_table_types_2eproto_sccs, descriptor_table_types_2eproto_deps, 1, 0, + schemas, file_default_instances, TableStruct_types_2eproto::offsets, + file_level_metadata_types_2eproto, 1, file_level_enum_descriptors_types_2eproto, file_level_service_descriptors_types_2eproto, +}; + +// Force running AddDescriptors() at dynamic initialization time. +static bool dynamic_init_dummy_types_2eproto = (static_cast(::PROTOBUF_NAMESPACE_ID::internal::AddDescriptors(&descriptor_table_types_2eproto)), true); +namespace tensorboard { +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* DataType_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_types_2eproto); + return file_level_enum_descriptors_types_2eproto[0]; +} +bool DataType_IsValid(int value) { + switch (value) { + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + case 8: + case 9: + case 10: + case 11: + case 12: + case 13: + case 14: + case 15: + case 16: + case 17: + case 18: + case 19: + case 20: + case 21: + case 22: + case 23: + case 24: + case 25: + case 29: + case 30: + case 101: + case 102: + case 103: + case 104: + case 105: + case 106: + case 107: + case 108: + case 109: + case 110: + case 111: + case 112: + case 113: + case 114: + case 115: + case 116: + case 117: + case 118: + case 119: + case 120: + case 121: + case 122: + case 123: + case 124: + case 125: + case 129: + case 130: + return true; + default: + return false; + } +} + + +// =================================================================== + +void SerializedDType::InitAsDefaultInstance() { +} +class SerializedDType::_Internal { + public: +}; + +SerializedDType::SerializedDType(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:tensorboard.SerializedDType) +} +SerializedDType::SerializedDType(const SerializedDType& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + datatype_ = from.datatype_; + // @@protoc_insertion_point(copy_constructor:tensorboard.SerializedDType) +} + +void SerializedDType::SharedCtor() { + datatype_ = 0; +} + +SerializedDType::~SerializedDType() { + // @@protoc_insertion_point(destructor:tensorboard.SerializedDType) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void SerializedDType::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void SerializedDType::ArenaDtor(void* object) { + SerializedDType* _this = reinterpret_cast< SerializedDType* >(object); + (void)_this; +} +void SerializedDType::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void SerializedDType::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const SerializedDType& SerializedDType::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_SerializedDType_types_2eproto.base); + return *internal_default_instance(); +} + + +void SerializedDType::Clear() { +// @@protoc_insertion_point(message_clear_start:tensorboard.SerializedDType) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + datatype_ = 0; + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* SerializedDType::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // .tensorboard.DataType datatype = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + _internal_set_datatype(static_cast<::tensorboard::DataType>(val)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* SerializedDType::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:tensorboard.SerializedDType) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // .tensorboard.DataType datatype = 1; + if (this->datatype() != 0) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_datatype(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:tensorboard.SerializedDType) + return target; +} + +size_t SerializedDType::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:tensorboard.SerializedDType) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // .tensorboard.DataType datatype = 1; + if (this->datatype() != 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_datatype()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void SerializedDType::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:tensorboard.SerializedDType) + GOOGLE_DCHECK_NE(&from, this); + const SerializedDType* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:tensorboard.SerializedDType) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:tensorboard.SerializedDType) + MergeFrom(*source); + } +} + +void SerializedDType::MergeFrom(const SerializedDType& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:tensorboard.SerializedDType) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.datatype() != 0) { + _internal_set_datatype(from._internal_datatype()); + } +} + +void SerializedDType::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:tensorboard.SerializedDType) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void SerializedDType::CopyFrom(const SerializedDType& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:tensorboard.SerializedDType) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool SerializedDType::IsInitialized() const { + return true; +} + +void SerializedDType::InternalSwap(SerializedDType* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(datatype_, other->datatype_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata SerializedDType::GetMetadata() const { + return GetMetadataStatic(); +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_NOINLINE ::tensorboard::SerializedDType* Arena::CreateMaybeMessage< ::tensorboard::SerializedDType >(Arena* arena) { + return Arena::CreateMessageInternal< ::tensorboard::SerializedDType >(arena); +} +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) +#include diff --git a/plugins/mindstudio-insight-plugins/proto/types.pb.h b/plugins/mindstudio-insight-plugins/proto/types.pb.h new file mode 100644 index 0000000000000000000000000000000000000000..565c0e7a2a1bab54e2d447fb258e3e0c6b98c213 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/types.pb.h @@ -0,0 +1,335 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: types.proto + +#ifndef GOOGLE_PROTOBUF_INCLUDED_types_2eproto +#define GOOGLE_PROTOBUF_INCLUDED_types_2eproto + +#include +#include + +#include +#if PROTOBUF_VERSION < 3013000 +#error This file was generated by a newer version of protoc which is +#error incompatible with your Protocol Buffer headers. Please update +#error your headers. +#endif +#if 3013000 < PROTOBUF_MIN_PROTOC_VERSION +#error This file was generated by an older version of protoc which is +#error incompatible with your Protocol Buffer headers. Please +#error regenerate this file with a newer version of protoc. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include // IWYU pragma: export +#include // IWYU pragma: export +#include +#include +// @@protoc_insertion_point(includes) +#include +#define PROTOBUF_INTERNAL_EXPORT_types_2eproto +PROTOBUF_NAMESPACE_OPEN +namespace internal { +class AnyMetadata; +} // namespace internal +PROTOBUF_NAMESPACE_CLOSE + +// Internal implementation detail -- do not use these members. +struct TableStruct_types_2eproto { + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[1] + PROTOBUF_SECTION_VARIABLE(protodesc_cold); + static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[]; + static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[]; + static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[]; +}; +extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_types_2eproto; +namespace tensorboard { +class SerializedDType; +class SerializedDTypeDefaultTypeInternal; +extern SerializedDTypeDefaultTypeInternal _SerializedDType_default_instance_; +} // namespace tensorboard +PROTOBUF_NAMESPACE_OPEN +template<> ::tensorboard::SerializedDType* Arena::CreateMaybeMessage<::tensorboard::SerializedDType>(Arena*); +PROTOBUF_NAMESPACE_CLOSE +namespace tensorboard { + +enum DataType : int { + DT_INVALID = 0, + DT_FLOAT = 1, + DT_DOUBLE = 2, + DT_INT32 = 3, + DT_UINT8 = 4, + DT_INT16 = 5, + DT_INT8 = 6, + DT_STRING = 7, + DT_COMPLEX64 = 8, + DT_INT64 = 9, + DT_BOOL = 10, + DT_QINT8 = 11, + DT_QUINT8 = 12, + DT_QINT32 = 13, + DT_BFLOAT16 = 14, + DT_QINT16 = 15, + DT_QUINT16 = 16, + DT_UINT16 = 17, + DT_COMPLEX128 = 18, + DT_HALF = 19, + DT_RESOURCE = 20, + DT_VARIANT = 21, + DT_UINT32 = 22, + DT_UINT64 = 23, + DT_FLOAT8_E5M2 = 24, + DT_FLOAT8_E4M3FN = 25, + DT_INT4 = 29, + DT_UINT4 = 30, + DT_FLOAT_REF = 101, + DT_DOUBLE_REF = 102, + DT_INT32_REF = 103, + DT_UINT8_REF = 104, + DT_INT16_REF = 105, + DT_INT8_REF = 106, + DT_STRING_REF = 107, + DT_COMPLEX64_REF = 108, + DT_INT64_REF = 109, + DT_BOOL_REF = 110, + DT_QINT8_REF = 111, + DT_QUINT8_REF = 112, + DT_QINT32_REF = 113, + DT_BFLOAT16_REF = 114, + DT_QINT16_REF = 115, + DT_QUINT16_REF = 116, + DT_UINT16_REF = 117, + DT_COMPLEX128_REF = 118, + DT_HALF_REF = 119, + DT_RESOURCE_REF = 120, + DT_VARIANT_REF = 121, + DT_UINT32_REF = 122, + DT_UINT64_REF = 123, + DT_FLOAT8_E5M2_REF = 124, + DT_FLOAT8_E4M3FN_REF = 125, + DT_INT4_REF = 129, + DT_UINT4_REF = 130, + DataType_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::min(), + DataType_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::PROTOBUF_NAMESPACE_ID::int32>::max() +}; +bool DataType_IsValid(int value); +constexpr DataType DataType_MIN = DT_INVALID; +constexpr DataType DataType_MAX = DT_UINT4_REF; +constexpr int DataType_ARRAYSIZE = DataType_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* DataType_descriptor(); +template +inline const std::string& DataType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function DataType_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + DataType_descriptor(), enum_t_value); +} +inline bool DataType_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, DataType* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + DataType_descriptor(), name, value); +} +// =================================================================== + +class SerializedDType PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:tensorboard.SerializedDType) */ { + public: + inline SerializedDType() : SerializedDType(nullptr) {} + virtual ~SerializedDType(); + + SerializedDType(const SerializedDType& from); + SerializedDType(SerializedDType&& from) noexcept + : SerializedDType() { + *this = ::std::move(from); + } + + inline SerializedDType& operator=(const SerializedDType& from) { + CopyFrom(from); + return *this; + } + inline SerializedDType& operator=(SerializedDType&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const SerializedDType& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const SerializedDType* internal_default_instance() { + return reinterpret_cast( + &_SerializedDType_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(SerializedDType& a, SerializedDType& b) { + a.Swap(&b); + } + inline void Swap(SerializedDType* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(SerializedDType* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline SerializedDType* New() const final { + return CreateMaybeMessage(nullptr); + } + + SerializedDType* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const SerializedDType& from); + void MergeFrom(const SerializedDType& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(SerializedDType* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "tensorboard.SerializedDType"; + } + protected: + explicit SerializedDType(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_types_2eproto); + return ::descriptor_table_types_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDatatypeFieldNumber = 1, + }; + // .tensorboard.DataType datatype = 1; + void clear_datatype(); + ::tensorboard::DataType datatype() const; + void set_datatype(::tensorboard::DataType value); + private: + ::tensorboard::DataType _internal_datatype() const; + void _internal_set_datatype(::tensorboard::DataType value); + public: + + // @@protoc_insertion_point(class_scope:tensorboard.SerializedDType) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + int datatype_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_types_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// SerializedDType + +// .tensorboard.DataType datatype = 1; +inline void SerializedDType::clear_datatype() { + datatype_ = 0; +} +inline ::tensorboard::DataType SerializedDType::_internal_datatype() const { + return static_cast< ::tensorboard::DataType >(datatype_); +} +inline ::tensorboard::DataType SerializedDType::datatype() const { + // @@protoc_insertion_point(field_get:tensorboard.SerializedDType.datatype) + return _internal_datatype(); +} +inline void SerializedDType::_internal_set_datatype(::tensorboard::DataType value) { + + datatype_ = value; +} +inline void SerializedDType::set_datatype(::tensorboard::DataType value) { + _internal_set_datatype(value); + // @@protoc_insertion_point(field_set:tensorboard.SerializedDType.datatype) +} + +#ifdef __GNUC__ + #pragma GCC diagnostic pop +#endif // __GNUC__ + +// @@protoc_insertion_point(namespace_scope) + +} // namespace tensorboard + +PROTOBUF_NAMESPACE_OPEN + +template <> struct is_proto_enum< ::tensorboard::DataType> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::tensorboard::DataType>() { + return ::tensorboard::DataType_descriptor(); +} + +PROTOBUF_NAMESPACE_CLOSE + +// @@protoc_insertion_point(global_scope) + +#include +#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_types_2eproto diff --git a/plugins/mindstudio-insight-plugins/proto/types.proto b/plugins/mindstudio-insight-plugins/proto/types.proto new file mode 100644 index 0000000000000000000000000000000000000000..89d5411419ad7fda1375f6588208f34b07f4661f --- /dev/null +++ b/plugins/mindstudio-insight-plugins/proto/types.proto @@ -0,0 +1,100 @@ +syntax = "proto3"; + +package tensorboard; + +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/types_go_proto"; + +// (== suppress_warning documentation-presence ==) +// DISABLED.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + DT_FLOAT8_E5M2 = 24; // 5 exponent bits, 2 mantissa bits. + DT_FLOAT8_E4M3FN = 25; // 4 exponent bits, 3 mantissa bits, finite-only, with + // 2 NaNs (0bS1111111). + // TODO - b/299182407: Leaving room for remaining float8 types. + // DT_FLOAT8_E4M3FNUZ = 26; + // DT_FLOAT8_E4M3B11FNUZ = 27; + // DT_FLOAT8_E5M2FNUZ = 28; + DT_INT4 = 29; + DT_UINT4 = 30; + + // Do not use! These are only for TF1's obsolete reference Variables. + // Every enum above should have a corresponding value below (verified by + // types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; + DT_FLOAT8_E5M2_REF = 124; + DT_FLOAT8_E4M3FN_REF = 125; + // TODO - b/299182407: Leaving room for remaining float8 types. + // DT_FLOAT8_E4M3FNUZ_REF = 126; + // DT_FLOAT8_E4M3B11FNUZ_REF = 127; + // DT_FLOAT8_E5M2FNUZ_REF = 128; + DT_INT4_REF = 129; + DT_UINT4_REF = 130; +} +// DISABLED.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/tf_datatype.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorboard/compat/proto/tensor.cc, +// https://www.tensorflow.org/code/tensorboard/compat/proto/types.h, +// https://www.tensorflow.org/code/tensorboard/compat/proto/types.cc, +// https://www.tensorflow.org/code/tensorboard/compat/proto/dtypes.py, +// https://www.tensorflow.org/code/tensorboard/compat/proto/function.py) + +// Represents a serialized tf.dtypes.Dtype +message SerializedDType { + DataType datatype = 1; +} diff --git a/plugins/mindstudio-insight-plugins/rapidjson.tar.gz b/plugins/mindstudio-insight-plugins/rapidjson.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..bdc395113f547ab18c820df65e055b34c329d131 Binary files /dev/null and b/plugins/mindstudio-insight-plugins/rapidjson.tar.gz differ diff --git a/plugins/mindstudio-insight-plugins/tools/httpServer/CMakeLists.txt b/plugins/mindstudio-insight-plugins/tools/httpServer/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a4cb12eddaa91593942470dbb5005912ff676c95 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/tools/httpServer/CMakeLists.txt @@ -0,0 +1,25 @@ +project(HttpServer) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_C_STANDARD 11) +set(CMAKE_BUILD_TYPE Debug CACHE STRING "Build type" FORCE) + +set(EXECUTABLE_OUTPUT_PATH ${PROJECT_ROOT_DIR}/output) +aux_source_directory(. MAIN_SRC) +list(APPEND MAIN_SRC + ${MAIN_SRC} + ${U_SOCKETS_SRC}) +add_executable(${PROJECT_NAME} ${MAIN_SRC} + ../../plugin_core/src/Logger.cpp) +#if (CMAKE_BUILD_TYPE STREQUAL "Debug") +target_compile_options(${PROJECT_NAME} PUBLIC -g -O0) +#endif () +target_compile_definitions(${PROJECT_NAME} PRIVATE PLUGINS_DIR="${PROJECT_ROOT_DIR}/output/plugins") +target_include_directories(${PROJECT_NAME} PUBLIC ${MAIN_INCLUDE}) +target_include_directories(${PROJECT_NAME} PUBLIC ${PROJECT_ROOT_DIR}/plugin_core/include) +target_link_libraries(${PROJECT_NAME} PUBLIC ${PROJECT_ROOT_DIR}/output/lib/libmsinsight.so) +target_include_directories(${PROJECT_NAME} PRIVATE ${PROJECT_ROOT_DIR}/rapidjson/include/rapidjson) + +target_link_libraries(${PROJECT_NAME} PUBLIC uv_a) +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + target_link_libraries(${PROJECT_NAME} PUBLIC stdc++fs) +endif () diff --git a/plugins/mindstudio-insight-plugins/tools/httpServer/HttpServer.cpp b/plugins/mindstudio-insight-plugins/tools/httpServer/HttpServer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9da0f552cacdda4a64aa0777f7cf9bbab20861e0 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/tools/httpServer/HttpServer.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#include +#include +#include "stringbuffer.h" +#include "writer.h" +#include "HttpServer.h" +#include "PluginsManager.h" +#include "Logger.h" + +using namespace Insight; +namespace Insight::Http { +using json_t = rapidjson::Value; +using document_t = rapidjson::Document; + +HttpServer &HttpServer::Instance() { + static HttpServer instance; + return instance; +} + +static inline std::string DumpJsonToStr(json_t &jsonSrc) { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + jsonSrc.Accept(writer); + return {buffer.GetString(), buffer.GetSize()}; +} + +bool HttpServer::Start() { + wsApp = std::make_unique(); + AddApiHandlers(); + wsApp->options("/*", [](auto *res, auto *req) { + res->end(); + }); + wsApp->listen("0.0.0.0", port, [](auto *token) { + if (token) { + LOG(LogRank::Info) << "http server start"; + } + }).run(); + return false; +} + +void HttpServer::AddApiHandlers() { + wsApp->get("test", [this](HttpResponse *res, auto *req) { + std::cout << "test" << std::endl; + res->tryEnd("", 0); + }); + auto &manager = Dic::Core::PluginsManager::Instance(); + Dic::Core::PluginsManager::LoadPlugins(); + for (const auto &[name, plugin]: manager.GetAllPlugins()) { + auto handlers = plugin->GetAllHandlers(); + for (const auto &[key, handler]: handlers) { + std::cout << "Add Handler:" << key << std::endl; + if (handler->GetApiType() == Dic::Core::API_TYPE::GET) { + std::cout << "Add Handler2" << std::endl; + AddGetHandler(std::string("/" + name + "/" + key), handler); + } else { + AddPostHandler(std::string("/" + name + "/" + key), handler); + } + } + } +} + +void HttpServer::AddGetHandler(std::string_view key, const std::shared_ptr handler) { + wsApp->get(key.data(), [handler, this](HttpResponse *res, HttpRequest *req) { + res->writeHeader("Access-Control-Allow-Origin", "*"); // allow CROS request + res->writeHeader("Access-Control-Allow-Methods", "GET, POST, OPTIONS"); + res->writeHeader("Access-Control-Allow-Headers", "Content-Type"); + res->writeHeader("Access-Control-Allow-Credentials", "true"); + std::string result = GetBasicResult(); + if (handler->run(req->getQuery(), result)) { + } + res->tryEnd(result, result.size()); + }); +} + +void HttpServer::AddPostHandler(std::string_view key, const std::shared_ptr handler) { + wsApp->post(key.data(), [handler, this](HttpResponse *res, auto *req) { + res->onAborted([]() { + Loop::get()->defer([]() { + }); + }); + res->onData([res, handler, bodyBuffer = std::string(), this](std::string_view data, bool isEnd) mutable { + bodyBuffer.append(data); + if (isEnd) { + // 处理数据 + res->writeHeader("Access-Control-Allow-Origin", "*"); // allow CROS request + res->writeHeader("Access-Control-Allow-Methods", "GET, POST, OPTIONS"); + res->writeHeader("Access-Control-Allow-Headers", "Content-Type"); + res->writeHeader("Access-Control-Allow-Credentials", "true"); + std::string result = GetBasicResult(); + bool sucess = handler->run(bodyBuffer, result); + // long data length can't send rightly with try end + res->end(result, true); + LOG(LogRank::Info) << "Response size:" << result.size(); + } + }); + }); +} + +std::string HttpServer::GetBasicResult() { + return R"({"body":{}, "msg":"", "errCode":0, "result":true})"; +} +} diff --git a/plugins/mindstudio-insight-plugins/tools/httpServer/HttpServer.h b/plugins/mindstudio-insight-plugins/tools/httpServer/HttpServer.h new file mode 100644 index 0000000000000000000000000000000000000000..52f468ade87afba8d5d0564c6c8a2be530a5d7d0 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/tools/httpServer/HttpServer.h @@ -0,0 +1,40 @@ +/* + * Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ +#ifndef BOARD_HTTPSERVER_H +#define BOARD_HTTPSERVER_H + +#include +#include +#include "App.h" +#include "ApiHandler.h" + +namespace Insight::Http { +using namespace uWS; +using namespace Dic::Core; + +class HttpServer { +public: + static HttpServer &Instance(); + + bool Start(); + +private: + HttpServer() = default; + + ~HttpServer() = default; + + void AddApiHandlers(); + + void AddGetHandler(std::string_view key, const std::shared_ptr handler); + + void AddPostHandler(std::string_view key, const std::shared_ptr handler); + + static std::string GetBasicResult(); + + std::unique_ptr wsApp; + uint16_t port{6065}; +}; +} + +#endif // BOARD_HTTPSERVER_H diff --git a/plugins/mindstudio-insight-plugins/tools/httpServer/main.cpp b/plugins/mindstudio-insight-plugins/tools/httpServer/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..354b7384000ab5c053c3928794dd9641c09e1304 --- /dev/null +++ b/plugins/mindstudio-insight-plugins/tools/httpServer/main.cpp @@ -0,0 +1,11 @@ +/* + * Copyright (c), Huawei Technologies Co., Ltd. 2024-2024.All rights reserved. + */ + +#include "HttpServer.h" + +using namespace Dic::Core; + +int main(int argc, char *argv[]) { + Insight::Http::HttpServer::Instance().Start(); +} \ No newline at end of file diff --git a/plugins/mindstudio-insight-plugins/uSockets-0.8.6.tar.gz b/plugins/mindstudio-insight-plugins/uSockets-0.8.6.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..74734b0b84112b1c524d6ab36217902253876c76 Binary files /dev/null and b/plugins/mindstudio-insight-plugins/uSockets-0.8.6.tar.gz differ diff --git a/plugins/tensorboard-plugins/.github/workflows/libkineto_ci.yml b/plugins/tensorboard-plugins/.github/workflows/libkineto_ci.yml deleted file mode 100644 index 3133d6400fb0b3ca0ee9b38c311c2db6d1167c7e..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/.github/workflows/libkineto_ci.yml +++ /dev/null @@ -1,56 +0,0 @@ -name: LIBKINETOCI - -on: - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - - steps: - - uses: actions/checkout@v2 - - name: Checkout submodules - shell: bash - run: | - auth_header="$(git config --local --get http.https://github.com/.extraheader)" - git submodule sync --recursive - git -c "http.extraheader=$auth_header" -c protocol.version=2 submodule update --init --force --recursive --depth=1 - - - name: Get env vars - run: | - echo GITHUB_WORKFLOW = $GITHUB_WORKFLOW - echo HOME = $HOME - echo GITHUB_ACTION = $GITHUB_ACTION - echo GITHUB_ACTIONS = $GITHUB_ACTIONS - echo GITHUB_REPOSITORY = $GITHUB_REPOSITORY - echo GITHUB_EVENT_NAME = $GITHUB_EVENT_NAME - echo GITHUB_EVENT_PATH = $GITHUB_EVENT_PATH - echo GITHUB_WORKSPACE = $GITHUB_WORKSPACE - echo GITHUB_SHA = $GITHUB_SHA - echo GITHUB_REF = $GITHUB_REF - c++ --verbose - - # TODO: Figure out how to install cupti headers T84637671 - - name: Build static lib - run: | - set -e - mkdir build_static - cd build_static - cmake -DKINETO_LIBRARY_TYPE=static ../libkineto/ - make -j - - - name: Build shared lib - run: | - set -e - mkdir build_shared - cd build_shared - cmake -DKINETO_LIBRARY_TYPE=shared ../libkineto/ - make -j diff --git a/plugins/tensorboard-plugins/.github/workflows/tb_plugin_build_pip_package.yml b/plugins/tensorboard-plugins/.github/workflows/tb_plugin_build_pip_package.yml deleted file mode 100644 index 9bdafcc442635eaff19fc7a7505f5231cf6e5cf7..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/.github/workflows/tb_plugin_build_pip_package.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Build torch-tb-profiler Pip Package - -on: - # TODO: Add an on_release trigger to build on tags - workflow_dispatch: - -jobs: - build-package: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: build pip package - run: | - set -e - cd tb_plugin - python setup.py sdist bdist_wheel - cd dist/ - pip install *.whl - python -c "import torch_tb_profiler;print(torch_tb_profiler.__version__)" diff --git a/plugins/tensorboard-plugins/.github/workflows/tb_plugin_ci.yml b/plugins/tensorboard-plugins/.github/workflows/tb_plugin_ci.yml deleted file mode 100644 index 1b59a7bf90a6009caa41d4ac0e3d5545dc8b6c7c..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/.github/workflows/tb_plugin_ci.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: TB_Plugin_CI - -on: - push: - branches: - - main - - release/** - - plugin/** - - pull_request: - branches: - - main - - release/** - - plugin/** - -jobs: - generate-matrix: - runs-on: ubuntu-latest - outputs: - matrix: ${{ steps.set-matrix.outputs.matrix }} - steps: - - id: set-matrix - run: | - echo $GITHUB_BASE_REF - if [ $GITHUB_BASE_REF == "plugin/vnext" ] - then - echo "::set-output name=matrix::{\"python-version\":[3.7, 3.8, 3.9], \"cuda-version\":[\"cpu\"], \"pytorch-version\":[\"nightly\"]}" - else - echo "::set-output name=matrix::{\"python-version\":[3.7, 3.8, 3.9], \"cuda-version\":[\"cpu\"], \"pytorch-version\":[\"nightly\", \"1.11rc\", \"stable\"]}" - fi - - build: - needs: generate-matrix - runs-on: ubuntu-latest - strategy: - matrix: ${{fromJSON(needs.generate-matrix.outputs.matrix)}} - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - architecture: 'x64' - - name: Test - env: - CUDA_VERSION: ${{ matrix.cuda-version }} - PYTORCH_VERSION: ${{ matrix.pytorch-version }} - TORCH_PROFILER_LOG_LEVEL: DEBUG - GRPC_VERBOSITY: DEBUG - GRPC_ENABLE_FORK_SUPPORT: 'False' - run: | - set -e - cd tb_plugin - sh ./ci_scripts/install_env.sh - pip install .[gs] - cd test - pytest diff --git a/plugins/tensorboard-plugins/.gitignore b/plugins/tensorboard-plugins/.gitignore deleted file mode 100644 index ce186381c0b566e0ca225be70cbf8ac233d7aa6b..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -# ignore common items -.idea -.vscode diff --git a/plugins/tensorboard-plugins/.gitmodules b/plugins/tensorboard-plugins/.gitmodules deleted file mode 100644 index 4660ee8bc9e6a4be4f4fbb007b8e66058122d716..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/.gitmodules +++ /dev/null @@ -1,6 +0,0 @@ -[submodule "libkineto/third_party/googletest"] - path = libkineto/third_party/googletest - url = https://github.com/google/googletest.git -[submodule "libkineto/third_party/fmt"] - path = libkineto/third_party/fmt - url = https://github.com/fmtlib/fmt.git diff --git a/plugins/tensorboard-plugins/CODE_OF_CONDUCT.md b/plugins/tensorboard-plugins/CODE_OF_CONDUCT.md deleted file mode 100644 index a0cbeaab7650bf08267fbdbc9bb54e845c88f392..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/CODE_OF_CONDUCT.md +++ /dev/null @@ -1,77 +0,0 @@ -# Code of Conduct - -## Our Pledge - -In the interest of fostering an open and welcoming environment, we as -contributors and maintainers pledge to make participation in our project and -our community a harassment-free experience for everyone, regardless of age, body -size, disability, ethnicity, sex characteristics, gender identity and expression, -level of experience, education, socio-economic status, nationality, personal -appearance, race, religion, or sexual identity and orientation. - -## Our Standards - -Examples of behavior that contributes to creating a positive environment -include: - -* Using welcoming and inclusive language -* Being respectful of differing viewpoints and experiences -* Gracefully accepting constructive criticism -* Focusing on what is best for the community -* Showing empathy towards other community members - -Examples of unacceptable behavior by participants include: - -* The use of sexualized language or imagery and unwelcome sexual attention or - advances -* Trolling, insulting/derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or electronic - address, without explicit permission -* Other conduct which could reasonably be considered inappropriate in a - professional setting - -## Our Responsibilities - -Project maintainers are responsible for clarifying the standards of acceptable -behavior and are expected to take appropriate and fair corrective action in -response to any instances of unacceptable behavior. - -Project maintainers have the right and responsibility to remove, edit, or -reject comments, commits, code, wiki edits, issues, and other contributions -that are not aligned to this Code of Conduct, or to ban temporarily or -permanently any contributor for other behaviors that they deem inappropriate, -threatening, offensive, or harmful. - -## Scope - -This Code of Conduct applies within all project spaces, and it also applies when -an individual is representing the project or its community in public spaces. -Examples of representing a project or community include using an official -project e-mail address, posting via an official social media account, or acting -as an appointed representative at an online or offline event. Representation of -a project may be further defined and clarified by project maintainers. - -## Enforcement - -Instances of abusive, harassing, or otherwise unacceptable behavior may be -reported by contacting the project team at . All -complaints will be reviewed and investigated and will result in a response that -is deemed necessary and appropriate to the circumstances. The project team is -obligated to maintain confidentiality with regard to the reporter of an incident. -Further details of specific enforcement policies may be posted separately. - -Project maintainers who do not follow or enforce the Code of Conduct in good -faith may face temporary or permanent repercussions as determined by other -members of the project's leadership. - -## Attribution - -This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, -available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html - -[homepage]: https://www.contributor-covenant.org - -For answers to common questions about this code of conduct, see -https://www.contributor-covenant.org/faq - diff --git a/plugins/tensorboard-plugins/CONTRIBUTING.md b/plugins/tensorboard-plugins/CONTRIBUTING.md deleted file mode 100644 index a2e931bb6f0cc82ff030cee10ee1c99fbbbda07b..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/CONTRIBUTING.md +++ /dev/null @@ -1,34 +0,0 @@ -# Contributing to Kineto -We want to make contributing to this project as easy and transparent as -possible. - -## Code of Conduct -The code of conduct is described in [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md). - -## Pull Requests -We actively welcome your pull requests. - -1. Fork the repo and create your branch from `main`. -2. If you've added code that should be tested, add tests. -3. If you've changed APIs, update the documentation. -4. Ensure the test suite passes. -5. Make sure your code lints. -6. If you haven't already, complete the Contributor License Agreement ("CLA"). - -## Contributor License Agreement ("CLA") -In order to accept your pull request, we need you to submit a CLA. You only need -to do this once to work on any of Facebook's open source projects. - -Complete your CLA here: - -## Issues -We use GitHub issues to track public bugs. Please ensure your description is -clear and has sufficient instructions to be able to reproduce the issue. - -Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe -disclosure of security bugs. In those cases, please go through the process -outlined on that page and do not file a public issue. - -## License -By contributing to Kineto, you agree that your contributions will be licensed -under the LICENSE file in the root directory of this source tree. diff --git a/plugins/tensorboard-plugins/LICENSE b/plugins/tensorboard-plugins/LICENSE deleted file mode 100644 index edb179715b5213644cfe903d43294f54892e707e..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/LICENSE +++ /dev/null @@ -1,33 +0,0 @@ -BSD License - -For Kineto software - -Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -All contributions by Microsoft: -Copyright (c) Microsoft Corporation. (The Azure AI Platform team) - -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - * Neither the name Facebook nor the names of its contributors may be used to - endorse or promote products derived from this software without specific - prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/plugins/tensorboard-plugins/README.md b/plugins/tensorboard-plugins/README.md deleted file mode 100644 index 3a18f4c6239f353c10362c9e0ba5aae052cb2c07..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# Kineto - -Kineto is part of the PyTorch Profiler. - -The Kineto project was started to help enable -- **performance observability and diagnostics** across common ML bottleneck components -- **actionable recommendations** for common issues -- integration of external system-level profiling tools -- integration with popular visualization platforms and analysis pipelines - -A central component is libkineto, a profiling library with special focus on low-overhead GPU timeline tracing. - -The PyTorch Profiler TensorBoard plugin provides powerful and intuitive visualizations of profiling results, as well as actionable recommendations, and is the best way to experience the new PyTorch Profiler. - -## Libkineto -Libkineto is an in-process profiling library integrated with the PyTorch Profiler. Please refer to the [README](libkineto/README.md) file in the `libkineto` folder as well as documentation on the [new PyTorch Profiler API](https://pytorch.org/docs/master/profiler.html). - -## PyTorch TensorBoard Profiler NPU Plugin -The goal of the PyTorch TensorBoard Profiler is to provide a seamless and intuitive end-to-end profiling experience, including straightforward collection from PyTorch and insightful visualizations and recommendations in the TensorBoard UI. -Please refer to the [README](tb_plugin/README.md) file in the `tb_plugin` folder. - -## Future Development Direction: -Some areas we're currently working on: -- Support for tracing distributed workloads -- Trace processing, analysis and recommendation engine -- System-level activities, multiple tracing sources -- Profiling and monitoring daemon for larger scale deployments - -## Releases and Contributing -We will follow the PyTorch release schedule which roughly happens on a 3 month basis. - -We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. - -If you plan to contribute new features, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR because we might be taking the infrastructure in a different direction than you might be aware of. We expect the architecture to keep evolving. - -## License -Kineto has a BSD-style license, as found in the [LICENSE](LICENSE) file. - diff --git a/plugins/tensorboard-plugins/libkineto/CMakeLists.txt b/plugins/tensorboard-plugins/libkineto/CMakeLists.txt deleted file mode 100644 index 63966de803a786913b104419776aa94bb00b74b0..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/CMakeLists.txt +++ /dev/null @@ -1,198 +0,0 @@ -cmake_minimum_required(VERSION 3.5 FATAL_ERROR) - -list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") - -#install libraries into correct locations on all platforms -include(GNUInstallDirs) - -# function to extract filelists from libkineto_defs.bzl file -find_package(PythonInterp) -function(get_filelist name outputvar) - execute_process( - COMMAND "${PYTHON_EXECUTABLE}" -c - "exec(open('libkineto_defs.bzl').read());print(';'.join(${name}))" - WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" - OUTPUT_VARIABLE _tempvar) - string(REPLACE "\n" "" _tempvar "${_tempvar}") - set(${outputvar} ${_tempvar} PARENT_SCOPE) -endfunction() - -project(kineto VERSION 0.1 LANGUAGES CXX C) - -set(KINETO_LIBRARY_TYPE "default" CACHE STRING - "Type of library (default, static or shared) to build") -set_property(CACHE KINETO_LIBRARY_TYPE PROPERTY STRINGS default shared) -option(KINETO_BUILD_TESTS "Build kineto unit tests" ON) - -set(LIBKINETO_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src") -set(LIBKINETO_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include") -set(LIBKINETO_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) -set(LIBKINETO_THIRDPARTY_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party") -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - -#We should default to a Release build -if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "") - set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE) -endif() - -if (NOT CUDA_SOURCE_DIR) - set(CUDA_SOURCE_DIR "$ENV{CUDA_SOURCE_DIR}") - message(INFO " CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}") -endif() - -if (NOT ROCM_SOURCE_DIR) - set(ROCM_SOURCE_DIR "$ENV{ROCM_SOURCE_DIR}") - message(INFO " ROCM_SOURCE_DIR = ${ROCM_SOURCE_DIR}") -endif() - -# Set LIBKINETO_NOCUPTI to explicitly disable CUPTI -# Otherwise, CUPTI is disabled if not found -IF (NOT CUDA_SOURCE_DIR OR NOT CUPTI_INCLUDE_DIR OR NOT CUDA_cupti_LIBRARY) - set(LIBKINETO_NOCUPTI ON CACHE BOOL "" FORCE) -endif() - -IF (NOT ROCM_SOURCE_DIR AND NOT ROCTRACER_INCLUDE_DIR) - set(LIBKINETO_NOROCTRACER ON CACHE BOOL "" FORCE) -endif() - -# Define file lists -if (LIBKINETO_NOCUPTI AND LIBKINETO_NOROCTRACER) - get_filelist("get_libkineto_cpu_only_srcs(with_api=False)" LIBKINETO_SRCS) - message(INFO " CUPTI unavailable or disabled - not building GPU profilers") -elseif(NOT LIBKINETO_NOROCTRACER) - get_filelist("get_libkineto_roctracer_srcs()" LIBKINETO_SRCS) - message(INFO " Building with roctracer") -else() - get_filelist("get_libkineto_cupti_srcs(with_api=False)" LIBKINETO_SRCS) -endif() -get_filelist("get_libkineto_public_headers()" LIBKINETO_PUBLIC_HEADERS) -get_filelist("get_libkineto_api_srcs()" LIBKINETO_API_SRCS) - -add_library(kineto_base OBJECT ${LIBKINETO_SRCS}) -add_library(kineto_api OBJECT ${LIBKINETO_API_SRCS}) - -# Make libraries depend on libkineto_defs.bzl -add_custom_target(libkineto_defs.bzl DEPENDS libkineto_defs.bzl) -add_dependencies(kineto_base libkineto_defs.bzl) - -set_target_properties(kineto_base kineto_api PROPERTIES - CXX_STANDARD 14 - CXX_STANDARD_REQUIRED YES - CXX_EXTENSIONS NO - CXX_VISIBILITY_PRESET hidden) - -set(KINETO_COMPILE_OPTIONS "-DKINETO_NAMESPACE=libkineto") -list(APPEND KINETO_COMPILE_OPTIONS "-DFMT_HEADER_ONLY") -if(NOT MSVC) - list(APPEND KINETO_COMPILE_OPTIONS "-std=c++14") -else() - list(APPEND KINETO_COMPILE_OPTIONS "/std:c++14") - list(APPEND KINETO_COMPILE_OPTIONS "-DWIN32_LEAN_AND_MEAN") - list(APPEND KINETO_COMPILE_OPTIONS "-DNOGDI") -endif() -if (NOT LIBKINETO_NOCUPTI) - list(APPEND KINETO_COMPILE_OPTIONS "-DHAS_CUPTI") -endif() -if (NOT LIBKINETO_NOROCTRACER) - target_compile_options(kineto_base PRIVATE "-DHAS_ROCTRACER") - target_compile_options(kineto_base PRIVATE "-D__HIP_PLATFORM_HCC__") - target_compile_options(kineto_base PRIVATE "-D__HIP_PLATFORM_AMD__") -endif() - -target_compile_options(kineto_base PRIVATE "${KINETO_COMPILE_OPTIONS}") -target_compile_options(kineto_api PRIVATE "${KINETO_COMPILE_OPTIONS}") - -if(NOT TARGET fmt) - if(NOT FMT_SOURCE_DIR) - set(FMT_SOURCE_DIR "${LIBKINETO_THIRDPARTY_DIR}/fmt" - CACHE STRING "fmt source directory from submodules") - endif() - - # Build FMT. - # FMT and some other libraries use BUILD_SHARED_LIBS to control - # the library type. - # Save and restore the value after configuring FMT - set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) - set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE) - set(FMT_LIBRARY_TYPE static CACHE STRING "Set lib type to static") - add_subdirectory("${FMT_SOURCE_DIR}" "${LIBKINETO_BINARY_DIR}/fmt") - set_property(TARGET fmt PROPERTY POSITION_INDEPENDENT_CODE ON) - set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE) -endif() - -set(FMT_INCLUDE_DIR "${FMT_SOURCE_DIR}/include") -message(STATUS "Kineto: FMT_SOURCE_DIR = ${FMT_SOURCE_DIR}") -message(STATUS "Kineto: FMT_INCLUDE_DIR = ${FMT_INCLUDE_DIR}") -if (NOT CUPTI_INCLUDE_DIR) - set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/extras/CUPTI/include") -endif() -if (NOT CUDA_INCLUDE_DIRS) - set(CUDA_INCLUDE_DIRS "${CUDA_SOURCE_DIR}/include") -endif() -if (NOT ROCTRACER_INCLUDE_DIR) - set(ROCTRACER_INCLUDE_DIR "${ROCM_SOURCE_DIR}/roctracer/include") -endif() -if (NOT ROCM_INCLUDE_DIRS) - set(ROCM_INCLUDE_DIRS "${ROCM_SOURCE_DIR}/include") -endif() - -message(INFO " CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}") -message(INFO " ROCTRACER_INCLUDE_DIR = ${ROCTRACER_INCLUDE_DIR}") - -target_include_directories(kineto_base PUBLIC - $ - $ - $ - $ - $ - $ - $) - -target_include_directories(kineto_api PUBLIC - $ - $) - -if(KINETO_LIBRARY_TYPE STREQUAL "default") - add_library(kineto - $ - $) -elseif(KINETO_LIBRARY_TYPE STREQUAL "static") - add_library(kineto STATIC - $ - $) -elseif(KINETO_LIBRARY_TYPE STREQUAL "shared") - add_library(kineto SHARED - $) - set_property(TARGET kineto_base PROPERTY POSITION_INDEPENDENT_CODE ON) - set_target_properties(kineto PROPERTIES - CXX_VISIBILITY_PRESET hidden) -else() - message(FATAL_ERROR "Unsupported library type ${KINETO_LIBRARY_TYPE}") -endif() - -if(NOT LIBKINETO_NOROCTRACER) - find_library(ROCTRACER_LIBRARY NAMES libroctracer64.so HINTS /opt/rocm/roctracer/lib) - target_link_libraries(kineto "${ROCTRACER_LIBRARY}") - find_library(KINETO_HIP_LIBRARY NAMES libamdhip64.so HINTS /opt/rocm/lib) - target_link_libraries(kineto "${KINETO_HIP_LIBRARY}") -endif() - -if(NOT LIBKINETO_NOCUPTI) - target_link_libraries(kineto "${CUDA_cupti_LIBRARY}") -endif() -target_link_libraries(kineto $) -add_dependencies(kineto fmt::fmt-header-only) - -install(TARGETS kineto EXPORT kinetoLibraryConfig - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) - -install(FILES ${LIBKINETO_PUBLIC_HEADERS} - DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kineto") - -install(EXPORT kinetoLibraryConfig DESTINATION share/cmake/kineto - FILE kinetoLibraryConfig.cmake) - -if(KINETO_BUILD_TESTS) - add_subdirectory(test) -endif() diff --git a/plugins/tensorboard-plugins/libkineto/README.md b/plugins/tensorboard-plugins/libkineto/README.md deleted file mode 100644 index 37127ca5aa821217da48aad38cb82eb36f8735c2..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/README.md +++ /dev/null @@ -1,65 +0,0 @@ -# Libkineto - -Libkineto is an in-process profiling library, part of the Kineto performance -tools project. - -The library provides a way to collect GPU traces and metrics from the host -process, either via the library public API or by sending a signal, if enabled. - -Currently only NVIDIA GPUs are supported. - -## Build Notes -Libkineto uses the standard CMAKE-based build flow. - -### Dependencies -Libkineto requires gcc 5+ and: - -- NVIDIA CUPTI: used to collect traces and metrics from NVIDIA GPUs. -- fmt: used for its convenient and lightweight string formatting functionality. -- googletest: required to build and run Kineto's tests. - - **googletest is not required** if you don't want to run Kineto tests. -By default, building of tests is **on**. Turn it off by setting `KINETO_BUILD_TESTS` to **off**. - -You can download [NVIDIA CUPTI][1], [fmt][2], [googletest][3] and set -`CUDA_SOURCE_DIR`, `FMT_SOURCE_DIR`, `GOOGLETEST_SOURCE_DIR` respectively for -cmake to find these libraries. If the fmt and googletest variables are not set, cmake will -build the git submodules found in the `third_party` directory. -If `CUDA_SOURCE_DIR` is not set, libkineto will fail to build. - -### Building Libkineto - -``` -# Check out repo and sub modules -git clone --recursive https://github.com/pytorch/kineto.git -# Build libkineto with cmake -cd kineto/libkineto -mkdir build && cd build -cmake .. -make -``` - -To run the tests after building libkineto (if tests are built), use the following -command: -``` -make test -``` - -### Installing Libkineto -``` -make install -``` - -## How Libkineto works -We will provide a high-level overview, design philosophy and brief descriptions of various -parts of Libkineto in upcoming blogs. - -## Full documentation -We strive to keep our source files readable. The best and up-to-date -documentation is available in the source files. - -## License -Libkineto is BSD licensed, as detailed in the [LICENSE](../LICENSE) file. - -[1]:https://developer.nvidia.com/CUPTI-CTK10_2 -[2]:https://github.com/fmt -[3]:https://github.com/google/googletest diff --git a/plugins/tensorboard-plugins/libkineto/include/AbstractConfig.h b/plugins/tensorboard-plugins/libkineto/include/AbstractConfig.h deleted file mode 100644 index 1cadf4906c11c3b5f59e290295048cee7fd63acf..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/AbstractConfig.h +++ /dev/null @@ -1,113 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include - -namespace KINETO_NAMESPACE { - -class AbstractConfig { - public: - AbstractConfig& operator=(const AbstractConfig&) = delete; - AbstractConfig(AbstractConfig&&) = delete; - AbstractConfig& operator=(AbstractConfig&&) = delete; - - virtual ~AbstractConfig() { - for (const auto& p : featureConfigs_) { - delete p.second; - } - } - - // Return a copy of the full derived class - virtual AbstractConfig* cloneDerived(AbstractConfig& parent) const = 0; - - // Returns true if successfully parsed the config string - bool parse(const std::string& conf); - - // Default setup for signal-triggered profiling - virtual void setSignalDefaults() { - for (auto& p : featureConfigs_) { - p.second->setSignalDefaults(); - } - } - - // Default setup for client-triggered profiling - virtual void setClientDefaults() { - for (auto& p : featureConfigs_) { - p.second->setClientDefaults(); - } - } - - // Time config was created / updated - std::chrono::time_point timestamp() const { - return timestamp_; - } - - // Source config string that this was parsed from - const std::string& source() const { - return source_; - } - - AbstractConfig& feature(std::string name) const { - const auto& pos = featureConfigs_.find(name); - return *pos->second; - } - - // Transfers ownership of cfg arg - void addFeature(const std::string& name, AbstractConfig* cfg) { - featureConfigs_[name] = cfg; - } - - protected: - AbstractConfig() {} - AbstractConfig(const AbstractConfig& other) = default; - - // Return true if the option was recognized and successfully parsed. - // Throw std::invalid_argument if val is invalid. - virtual bool handleOption(const std::string& name, std::string& val); - - // Perform post-validation checks, typically conditons involving - // multiple options. - // Throw std::invalid_argument if automatic correction can not be made. - // - // @param fallbackProfileStartTime Specify a fallback profile start timestamp in case it was never specified by the client - virtual void validate(const std::chrono::time_point& fallbackProfileStartTime) = 0; - - // TODO: Separate out each profiler type into features? - virtual void printActivityProfilerConfig(std::ostream& s) const; - - // Helpers for use in handleOption - // Split a string by delimiter and remove external white space - std::vector splitAndTrim(const std::string& s, char delim) const; - // Lowercase for case-insensitive comparisons - std::string toLower(std::string& s) const; - // Does string end with suffix - bool endsWith(const std::string& s, const std::string& suffix) const; - // Conversions - int64_t toIntRange(const std::string& val, int64_t min, int64_t max) const; - int32_t toInt32(const std::string& val) const; - int64_t toInt64(const std::string& val) const; - bool toBool(std::string& val) const; - - void cloneFeaturesInto(AbstractConfig& cfg) const { - for (const auto& feature : featureConfigs_) { - cfg.featureConfigs_[feature.first] = feature.second->cloneDerived(cfg); - } - } - - private: - // Time config was created / updated - std::chrono::time_point timestamp_{}; - - // Original configuration string, used for comparison - std::string source_{""}; - - // Configuration objects for optional features - std::map featureConfigs_{}; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/include/ActivityProfilerInterface.h b/plugins/tensorboard-plugins/libkineto/include/ActivityProfilerInterface.h deleted file mode 100644 index 29871e47ab8af87888ccb8e20403bc26c433b5cc..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/ActivityProfilerInterface.h +++ /dev/null @@ -1,91 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include - -#include "ActivityType.h" -#include "ActivityTraceInterface.h" -#include "IActivityProfiler.h" - -namespace libkineto { - -class ActivityProfilerController; -struct CpuTraceBuffer; -class Config; - -class ActivityProfilerInterface { - - public: - virtual ~ActivityProfilerInterface() {}; - - virtual void init() {} - virtual bool isInitialized() { - return false; - } - virtual bool isActive(){ - return false; - } - - // *** Asynchronous API *** - // Instead of starting and stopping the trace manually, provide a start time - // and duration and / or iteration stop criterion. - // Tracing terminates when either condition is met. - virtual void scheduleTrace(const std::string& configStr) {} - - // *** Synchronous API *** - // These must be called in order: - // prepareTrace -> startTrace -> stopTrace. - - // Many tracing structures are lazily initialized during trace collection, - // with potentially high overhead. - // Call prepareTrace to enable tracing, then run the region to trace - // at least once (and ideally run the same code that is to be traced) to - // allow tracing structures to be initialized. - virtual void prepareTrace( - const std::set& activityTypes, - const std::string& configStr = "") {} - - // Start recording, potentially reusing any buffers allocated since - // prepareTrace was called. - virtual void startTrace() {} - - // Stop and process trace, producing an in-memory list of trace records. - // The processing will be done synchronously (using the calling thread.) - virtual std::unique_ptr stopTrace() { - return nullptr; - } - - // Re-evaluate internal state to allow for triggering operations based - // on number of iteration. each implicitly increments the iteration count - virtual void step() {} - - // *** TraceActivity API *** - // FIXME: Pass activityProfiler interface into clientInterface? - virtual void pushCorrelationId(uint64_t id){} - virtual void popCorrelationId(){} - virtual void transferCpuTrace( - std::unique_ptr traceBuffer){} - - // Correlation ids for user defined spans - virtual void pushUserCorrelationId(uint64_t){} - virtual void popUserCorrelationId(){} - - // Saves information for the current thread to be used in profiler output - // Client must record any new kernel thread where the activity has occured. - virtual void recordThreadInfo() {} - - // Record trace metadata, currently supporting only string key and values, - // values with the same key are overwritten - virtual void addMetadata(const std::string& key, const std::string& value) = 0; - - // Add a child activity profiler, this enables frameworks in the application - // to enable custom framework events. - virtual void addChildActivityProfiler( - std::unique_ptr profiler) {} -}; - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/ActivityTraceInterface.h b/plugins/tensorboard-plugins/libkineto/include/ActivityTraceInterface.h deleted file mode 100644 index 23d4edab00ce2fa90427e13818ac09c8541835ac..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/ActivityTraceInterface.h +++ /dev/null @@ -1,21 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include - -namespace libkineto { - -struct ITraceActivity; - -class ActivityTraceInterface { - public: - virtual ~ActivityTraceInterface() {} - virtual const std::vector* activities() { - return nullptr; - } - virtual void save(const std::string& path) {} -}; - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/ActivityType.h b/plugins/tensorboard-plugins/libkineto/include/ActivityType.h deleted file mode 100644 index 74c6a2531d6a9cee3196f9f889517926afea823f..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/ActivityType.h +++ /dev/null @@ -1,34 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include - -namespace libkineto { - -enum class ActivityType { - CPU_OP = 0, // cpu side ops - USER_ANNOTATION, - GPU_USER_ANNOTATION, - GPU_MEMCPY, - GPU_MEMSET, - CONCURRENT_KERNEL, // on-device kernels - EXTERNAL_CORRELATION, - CUDA_RUNTIME, // host side cuda runtime events - CUDA_PROFILER_RANGE, // CUPTI Profiler range for performance metrics - GLOW_RUNTIME, // host side glow runtime events - CPU_INSTANT_EVENT, // host side point-like events - PYTHON_FUNCTION, - OVERHEAD, // CUPTI induced overhead events sampled from its overhead API. - ENUM_COUNT // This is to add buffer and not used for any profiling logic. Add your new type before it. -}; - -const char* toString(ActivityType t); -ActivityType toActivityType(const std::string& str); - -// Return an array of all activity types except COUNT -constexpr int activityTypeCount = (int)ActivityType::ENUM_COUNT; -const std::array activityTypes(); - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/ClientInterface.h b/plugins/tensorboard-plugins/libkineto/include/ClientInterface.h deleted file mode 100644 index 06dc075838164f80e9481b34a5d5d3c136b92efd..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/ClientInterface.h +++ /dev/null @@ -1,16 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -namespace libkineto { - -class ClientInterface { - public: - virtual ~ClientInterface() {} - virtual void init() = 0; - virtual void warmup(bool setupOpInputsCollection) = 0; - virtual void start() = 0; - virtual void stop() = 0; -}; - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/Config.h b/plugins/tensorboard-plugins/libkineto/include/Config.h deleted file mode 100644 index 040e96c9f75ab3ab768aaebac28f959f12a3ea06..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/Config.h +++ /dev/null @@ -1,433 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include "AbstractConfig.h" -#include "ActivityType.h" - -#include -#include -#include -#include -#include -#include - -namespace KINETO_NAMESPACE { - -using namespace libkineto; - -class Config : public AbstractConfig { - public: - Config(); - Config& operator=(const Config&) = delete; - Config(Config&&) = delete; - Config& operator=(Config&&) = delete; - - // Return a full copy including feature config object - std::unique_ptr clone() const { - auto cfg = std::unique_ptr(new Config(*this)); - cloneFeaturesInto(*cfg); - return cfg; - } - - bool handleOption(const std::string& name, std::string& val) override; - - void setClientDefaults() override; - - // Log events to this file - const std::string& eventLogFile() const { - return eventLogFile_; - } - - bool activityProfilerEnabled() const { - return activityProfilerEnabled_ || - activitiesOnDemandTimestamp_.time_since_epoch().count() > 0; - } - - // Log activitiy trace to this file - const std::string& activitiesLogFile() const { - return activitiesLogFile_; - } - - // Log activitiy trace to this url - const std::string& activitiesLogUrl() const { - return activitiesLogUrl_; - } - - void setActivitiesLogUrl(const std::string& url) { - activitiesLogUrl_ = url; - } - - bool activitiesLogToMemory() const { - return activitiesLogToMemory_; - } - - // Is profiling enabled for the given device? - bool eventProfilerEnabledForDevice(uint32_t dev) const { - return 0 != (eventProfilerDeviceMask_ & (1 << dev)); - } - - // Take a sample (read hardware counters) at this frequency. - // This controls how often counters are read - if all counters cannot - // be collected simultaneously then multiple samples are needed to - // collect all requested counters - see multiplex period. - std::chrono::milliseconds samplePeriod() const { - return samplePeriod_; - } - - void setSamplePeriod(std::chrono::milliseconds period) { - samplePeriod_ = period; - } - - // When all requested counters cannot be collected simultaneously, - // counters will be multiplexed at this frequency. - // Multiplexing can have a large performance impact if done frequently. - // To avoid a perf impact, keep this at 1s or above. - std::chrono::milliseconds multiplexPeriod() const { - return multiplexPeriod_; - } - - void setMultiplexPeriod(std::chrono::milliseconds period) { - multiplexPeriod_ = period; - } - - // Report counters at this frequency. Note that several samples can - // be reported each time, see samplesPerReport. - std::chrono::milliseconds reportPeriod() const { - return reportPeriod_; - } - - void setReportPeriod(std::chrono::milliseconds msecs); - - // Number of samples dispatched each report period. - // Must be in the range [1, report period / sample period]. - // In other words, aggregation is supported but not interpolation. - int samplesPerReport() const { - return samplesPerReport_; - } - - void setSamplesPerReport(int count) { - samplesPerReport_ = count; - } - - // The names of events to collect - const std::set& eventNames() const { - return eventNames_; - } - - // Add additional events to be profiled - void addEvents(const std::set& names) { - eventNames_.insert(names.begin(), names.end()); - } - - // The names of metrics to collect - const std::set& metricNames() const { - return metricNames_; - } - - // Add additional metrics to be profiled - void addMetrics(const std::set& names) { - metricNames_.insert(names.begin(), names.end()); - } - - const std::vector& percentiles() const { - return eventReportPercentiles_; - } - - // Profile for this long, then revert to base config - std::chrono::seconds eventProfilerOnDemandDuration() const { - return eventProfilerOnDemandDuration_; - } - - void setEventProfilerOnDemandDuration(std::chrono::seconds duration) { - eventProfilerOnDemandDuration_ = duration; - } - - // Too many event profilers on a single system can overload the driver. - // At some point, latencies shoot through the roof and collection of samples - // becomes impossible. To avoid this situation we have a limit of profilers - // per GPU. - // NOTE: Communication with a daemon is needed for this feature. - // Library must be built with an active DaemonConfigLoader. - int maxEventProfilersPerGpu() const { - return eventProfilerMaxInstancesPerGpu_; - } - - // On Cuda11 we've seen occasional hangs when reprogramming counters - // Monitor profiling threads and report when a thread is not responding - // for a given number of seconds. - // A period of 0 means disable. - std::chrono::seconds eventProfilerHeartbeatMonitorPeriod() const { - return eventProfilerHeartbeatMonitorPeriod_; - } - - // The types of activities selected in the configuration file - const std::set& selectedActivityTypes() const { - return selectedActivityTypes_; - } - - void setSelectedActivityTypes(const std::set& types) { - selectedActivityTypes_ = types; - } - - bool isOpInputsCollectionEnabled() const { - return enableOpInputsCollection_; - } - - // Trace for this long - std::chrono::milliseconds activitiesDuration() const { - return activitiesDuration_; - } - - // Trace for this many iterations, determined by external API - int activitiesRunIterations() const { - return activitiesRunIterations_; - } - - std::chrono::milliseconds activitiesDurationDefault() const; - - void setActivitiesDuration(std::chrono::milliseconds duration) { - activitiesDuration_ = duration; - } - - int activitiesMaxGpuBufferSize() const { - return activitiesMaxGpuBufferSize_; - } - - std::chrono::seconds activitiesWarmupDuration() const { - return activitiesWarmupDuration_; - } - - int activitiesWarmupIterations() const { - return activitiesWarmupIterations_; - } - - // Timestamp at which the profiling to start, requested by the user. - const std::chrono::time_point requestTimestamp() - const { - if (profileStartTime_.time_since_epoch().count()) { - return profileStartTime_; - } - - // TODO(T94634890): Deperecate requestTimestamp - return requestTimestamp_ + maxRequestAge() + activitiesWarmupDuration(); - } - - bool hasProfileStartTime() const { - return requestTimestamp_.time_since_epoch().count() > 0 || - profileStartTime_.time_since_epoch().count() > 0; - } - - int profileStartIteration() const { - return profileStartIteration_; - } - - bool hasProfileStartIteration() const { - return profileStartIteration_ >= 0 && activitiesRunIterations_ > 0; - } - - void setProfileStartIteration(int iter) { - profileStartIteration_ = iter; - } - - int profileStartIterationRoundUp() const { - return profileStartIterationRoundUp_; - } - - // calculate the start iteration accounting for warmup - int startIterationIncludingWarmup() const { - if (!hasProfileStartIteration()) { - return -1; - } - return profileStartIteration_ - activitiesWarmupIterations_; - } - - const std::chrono::seconds maxRequestAge() const; - - // All VLOG* macros will log if the verbose log level is >= - // the verbosity specified for the verbose log message. - // Default value is -1, so messages with log level 0 will log by default. - int verboseLogLevel() const { - return verboseLogLevel_; - } - - // Modules for which verbose logging is enabled. - // If empty, logging is enabled for all modules. - const std::vector& verboseLogModules() const { - return verboseLogModules_; - } - - bool sigUsr2Enabled() const { - return enableSigUsr2_; - } - - bool ipcFabricEnabled() const { - return enableIpcFabric_; - } - - static std::chrono::milliseconds alignUp( - std::chrono::milliseconds duration, - std::chrono::milliseconds alignment) { - duration += alignment; - return duration - (duration % alignment); - } - - std::chrono::time_point - eventProfilerOnDemandStartTime() const { - return eventProfilerOnDemandTimestamp_; - } - - std::chrono::time_point - eventProfilerOnDemandEndTime() const { - return eventProfilerOnDemandTimestamp_ + eventProfilerOnDemandDuration_; - } - - std::chrono::time_point - activityProfilerRequestReceivedTime() const { - return activitiesOnDemandTimestamp_; - } - - // Users may request and set trace id and group trace id. - const std::string& requestTraceID() const { - return requestTraceID_; - } - - void setRequestTraceID(const std::string& tid) { - requestTraceID_ = tid; - } - - const std::string& requestGroupTraceID() const { - return requestGroupTraceID_; - } - - void setRequestGroupTraceID(const std::string& gtid) { - requestGroupTraceID_ = gtid; - } - - void updateActivityProfilerRequestReceivedTime(); - - void printActivityProfilerConfig(std::ostream& s) const override; - - void validate( - const std::chrono::time_point& fallbackProfileStartTime) override; - - static void addConfigFactory( - std::string name, - std::function factory); - - void print(std::ostream& s) const; - - private: - explicit Config(const Config& other) = default; - - AbstractConfig* cloneDerived(AbstractConfig& parent) const override { - // Clone from AbstractConfig not supported - assert(false); - return nullptr; - } - - uint8_t createDeviceMask(const std::string& val); - - // Adds valid activity types from the user defined string list in the - // configuration file - void setActivityTypes(const std::vector& selected_activities); - - // Sets the default activity types to be traced - void selectDefaultActivityTypes() { - // If the user has not specified an activity list, add all types - for (ActivityType t : activityTypes()) { - // Do no enable this by default - // TODO: introduce optional types - if (t != ActivityType::OVERHEAD) { - selectedActivityTypes_.insert(t); - } - } - } - - int verboseLogLevel_; - std::vector verboseLogModules_; - - // Event profiler - // These settings are also supported in on-demand mode - std::chrono::milliseconds samplePeriod_; - std::chrono::milliseconds reportPeriod_; - int samplesPerReport_; - std::set eventNames_; - std::set metricNames_; - - // On-demand duration - std::chrono::seconds eventProfilerOnDemandDuration_; - // Last on-demand request - std::chrono::time_point - eventProfilerOnDemandTimestamp_; - - int eventProfilerMaxInstancesPerGpu_; - - // Monitor whether event profiler threads are stuck - // at this frequency - std::chrono::seconds eventProfilerHeartbeatMonitorPeriod_; - - // These settings can not be changed on-demand - std::string eventLogFile_; - std::vector eventReportPercentiles_ = {5, 25, 50, 75, 95}; - uint8_t eventProfilerDeviceMask_ = ~0; - std::chrono::milliseconds multiplexPeriod_; - - // Activity profiler - bool activityProfilerEnabled_; - std::set selectedActivityTypes_; - - // The activity profiler settings are all on-demand - std::string activitiesLogFile_; - - std::string activitiesLogUrl_; - - // Log activities to memory buffer - bool activitiesLogToMemory_{false}; - - int activitiesMaxGpuBufferSize_; - std::chrono::seconds activitiesWarmupDuration_; - int activitiesWarmupIterations_; - - // Client Interface - // Enable inputs collection when tracing ops - bool enableOpInputsCollection_{true}; - - // Profile for specified iterations and duration - std::chrono::milliseconds activitiesDuration_; - int activitiesRunIterations_; - - // Below are not used - // Use this net name for iteration count - std::string activitiesExternalAPIIterationsTarget_; - // Only profile nets that includes this in the name - std::vector activitiesExternalAPIFilter_; - // Only profile nets with at least this many operators - int activitiesExternalAPINetSizeThreshold_; - // Only profile nets with at least this many GPU operators - int activitiesExternalAPIGpuOpCountThreshold_; - // Last activity profiler request - std::chrono::time_point - activitiesOnDemandTimestamp_; - - // Synchronized start timestamp - std::chrono::time_point profileStartTime_; - // or start iteration - int profileStartIteration_; - int profileStartIterationRoundUp_; - - // DEPRECATED - std::chrono::time_point requestTimestamp_; - - // Enable profiling via SIGUSR2 - bool enableSigUsr2_; - - // Enable IPC Fabric instead of thrift communication - bool enableIpcFabric_; - - // Logger Metadata - std::string requestTraceID_; - std::string requestGroupTraceID_; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/include/GenericTraceActivity.h b/plugins/tensorboard-plugins/libkineto/include/GenericTraceActivity.h deleted file mode 100644 index 4272cf1efa4e7613a46c3684270b4e803853345b..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/GenericTraceActivity.h +++ /dev/null @@ -1,125 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include - -#include "ThreadUtil.h" -#include "ITraceActivity.h" -#include "TraceSpan.h" - -namespace libkineto { - -// Link type, used in GenericTraceActivity.flow.type -constexpr unsigned int kLinkFwdBwd = 1; -constexpr unsigned int kLinkAsyncCpuGpu = 2; - -// @lint-ignore-every CLANGTIDY cppcoreguidelines-non-private-member-variables-in-classes -// @lint-ignore-every CLANGTIDY cppcoreguidelines-pro-type-member-init -class GenericTraceActivity : public ITraceActivity { - - public: - GenericTraceActivity() : activityType(ActivityType::ENUM_COUNT), traceSpan_(NULL) {} - - GenericTraceActivity( - const TraceSpan& trace, ActivityType type, const std::string& name) - : activityType(type), activityName(name), traceSpan_(&trace) { - } - - int64_t deviceId() const override { - return device; - } - - int64_t resourceId() const override { - return resource; - } - - int32_t getThreadId() const override { - return threadId; - } - - int64_t timestamp() const override { - return startTime; - } - - int64_t duration() const override { - return endTime - startTime; - } - - int64_t correlationId() const override { - return id; - } - - ActivityType type() const override { - return activityType; - } - - const ITraceActivity* linkedActivity() const override { - return nullptr; - } - - int flowType() const override { - return flow.type; - } - - int flowId() const override { - return flow.id; - } - - bool flowStart() const override { - return flow.start; - } - - const std::string name() const override { - return activityName; - } - - const TraceSpan* traceSpan() const override { - return traceSpan_; - } - - void log(ActivityLogger& logger) const override; - - //Encode client side metadata as a key/value - template - void addMetadata(const std::string& key, const ValType& value) { - metadata_.push_back(fmt::format("\"{}\": {}", key, value)); - } - - void addMetadataQuoted(const std::string& key, const std::string& value) { - metadata_.push_back(fmt::format("\"{}\": \"{}\"", key, value)); - } - - const std::string metadataJson() const override { - return fmt::format("{}", fmt::join(metadata_, ", ")); - } - - virtual ~GenericTraceActivity() {}; - - int64_t startTime{0}; - int64_t endTime{0}; - int32_t id{0}; - int32_t device{0}; - int32_t resource{0}; - int32_t threadId{0}; - ActivityType activityType; - std::string activityName; - struct Flow { - Flow(): id(0), type(0), start(0) {} - // Ids must be unique within each type - uint32_t id : 27; - // Type will be used to connect flows between profilers, as - // well as look up flow information (name etc) - uint32_t type : 4; - uint32_t start : 1; - } flow; - - private: - const TraceSpan* traceSpan_; - std::vector metadata_; -}; - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/IActivityProfiler.h b/plugins/tensorboard-plugins/libkineto/include/IActivityProfiler.h deleted file mode 100644 index f5d4b3fb828a3348d948c6487acc6a9e5a18f836..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/IActivityProfiler.h +++ /dev/null @@ -1,104 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include - -#include "Config.h" -#include "GenericTraceActivity.h" - -/* This file includes an abstract base class for an activity profiler - * that can be implemented by multiple tracing agents in the application. - * The high level Kineto profiler can co-ordinate start and end of tracing - * and combine together events from multiple such activity profilers. - */ - -namespace libkineto { - -using namespace KINETO_NAMESPACE; - -#ifdef _MSC_VER -// workaround for the predefined ERROR macro on Windows -#undef ERROR -#endif // _MSC_VER - -enum class TraceStatus { - READY, // Accepting trace requests - WARMUP, // Performing trace warmup - RECORDING, // Actively collecting activities - PROCESSING, // Recording is complete, preparing results - ERROR, // One or more errors (and possibly also warnings) occurred. - WARNING, // One or more warnings occurred. -}; - -/* IActivityProfilerSession: - * an opaque object that can be used by a high level profiler to - * start/stop and return trace events. - */ -class IActivityProfilerSession { - - public: - virtual ~IActivityProfilerSession() {} - - // start the trace collection synchronously - virtual void start() = 0; - - // stop the trace collection synchronously - virtual void stop() = 0; - - TraceStatus status() { - return status_; - } - - // returns list of Trace Activities - virtual std::vector& activities() = 0; - - // returns errors with this trace - virtual std::vector errors() = 0; - - // processes trace activities using logger - virtual void processTrace(ActivityLogger& logger) = 0; - - // XXX define trace formats - // virtual save(string name, TraceFormat format) - - protected: - TraceStatus status_ = TraceStatus::READY; -}; - - -/* Activity Profiler Plugins: - * These allow other frameworks to integrate into Kineto's primariy - * activity profiler. While the primary activity profiler handles - * timing the trace collections and correlating events the plugins - * can become source of new trace activity types. - */ -class IActivityProfiler { - - public: - - virtual ~IActivityProfiler() {} - - // name of profiler - virtual const std::string& name() const = 0; - - // returns activity types this profiler supports - virtual const std::set& availableActivities() const = 0; - - // Calls prepare() on registered tracer providers passing in the relevant - // activity types. Returns a profiler session handle - virtual std::unique_ptr configure( - const std::set& activity_types, - const Config& config) = 0; - - // asynchronous version of the above with future timestamp and duration. - virtual std::unique_ptr configure( - int64_t ts_ms, - int64_t duration_ms, - const std::set& activity_types, - const Config& config) = 0; -}; - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/ILoggerObserver.h b/plugins/tensorboard-plugins/libkineto/include/ILoggerObserver.h deleted file mode 100644 index 4fce7851b9669ff93a3f3a772140b0466674853c..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/ILoggerObserver.h +++ /dev/null @@ -1,50 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include - -// Stages in libkineto used when pushing logs to UST Logger. -constexpr char kWarmUpStage[] = "Warm Up"; -constexpr char kCollectionStage[] = "Collection"; -constexpr char kPostProcessingStage[] = "Post Processing"; - -#if !USE_GOOGLE_LOG - -#include -#include - -namespace libkineto { - -enum LoggerOutputType { - VERBOSE = 0, - INFO = 1, - WARNING = 2, - ERROR = 3, - STAGE = 4, - ENUM_COUNT = 5 -}; - -const char* toString(LoggerOutputType t); -LoggerOutputType toLoggerOutputType(const std::string& str); - -constexpr int LoggerTypeCount = (int) LoggerOutputType::ENUM_COUNT; - -class ILoggerObserver { - public: - virtual ~ILoggerObserver() = default; - virtual void write(const std::string& message, LoggerOutputType ot) = 0; - virtual const std::map> extractCollectorMetadata() = 0; - virtual void reset() = 0; - virtual void addDevice(const int64_t device) = 0; - virtual void setTraceDurationMS(const int64_t duration) = 0; - virtual void addEventCount(const int64_t count) = 0; - virtual void setTraceID(const std::string&) {} - virtual void setGroupTraceID(const std::string&) {} - virtual void addDestination(const std::string& dest) = 0; - -}; - -} // namespace libkineto - -#endif // !USE_GOOGLE_LOG diff --git a/plugins/tensorboard-plugins/libkineto/include/ITraceActivity.h b/plugins/tensorboard-plugins/libkineto/include/ITraceActivity.h deleted file mode 100644 index a477ed814662cb4c57738b7e40ec6052e9f65288..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/ITraceActivity.h +++ /dev/null @@ -1,53 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include - -#include "ActivityType.h" - -namespace libkineto { - -class ActivityLogger; -struct TraceSpan; - -// Generic activity interface is borrowed from tensorboard protobuf format. -struct ITraceActivity { - virtual ~ITraceActivity() {} - // Device is a physical or logical entity, e.g. CPU, GPU or process - virtual int64_t deviceId() const = 0; - // A resource is something on the device, h/w thread, - // functional units etc. - virtual int64_t resourceId() const = 0; - // s/w thread - virtual int32_t getThreadId() const = 0; - // Start timestamp in mucrosecond - virtual int64_t timestamp() const = 0; - // Duration in microseconds - virtual int64_t duration() const = 0; - // Used to link up async activities - virtual int64_t correlationId() const = 0; - // Part of a flow, identified by flow id and type - virtual int flowType() const = 0; - virtual int flowId() const = 0; - virtual bool flowStart() const = 0; - virtual ActivityType type() const = 0; - virtual const std::string name() const = 0; - // Optional linked activity - virtual const ITraceActivity* linkedActivity() const = 0; - // Optional containing trace object - virtual const TraceSpan* traceSpan() const = 0; - // Log activity - virtual void log(ActivityLogger& logger) const = 0; - // Return json formatted metadata - // FIXME: Return iterator to dynamic type map here instead - virtual const std::string metadataJson() const = 0; - - static int64_t nsToUs(int64_t ns) { - // It's important that this conversion is the same everywhere. - // No rounding! - return ns / 1000; - } -}; - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/ThreadUtil.h b/plugins/tensorboard-plugins/libkineto/include/ThreadUtil.h deleted file mode 100644 index d1dc80ad2ab0dfd3bea313363fb0e6565349889c..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/ThreadUtil.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace libkineto { - -int32_t systemThreadId(); -int32_t threadId(); -bool setThreadName(const std::string& name); -std::string getThreadName(); - -int32_t processId(); -std::string processName(int32_t pid); - -// Return a list of pids and process names for the current process -// and its parents. -std::vector> pidCommandPairsOfAncestors(); - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/TraceSpan.h b/plugins/tensorboard-plugins/libkineto/include/TraceSpan.h deleted file mode 100644 index af9a9d5ee556830ac34568e6c81ec4f8f00da2e3..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/TraceSpan.h +++ /dev/null @@ -1,36 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include - -namespace libkineto { - -struct TraceSpan { - TraceSpan() = delete; - TraceSpan( - int64_t startTime, int64_t endTime, std::string name) - : startTime(startTime), endTime(endTime), name(std::move(name)) { - } - TraceSpan( - int opCount, int it, std::string name, std::string prefix) - : opCount(opCount), - iteration(it), - name(std::move(name)), - prefix(std::move(prefix)) { - } - - // FIXME: change to duration? - int64_t startTime{0}; - int64_t endTime{0}; - int opCount{0}; - int iteration{-1}; - // Name is used to identify timeline - std::string name; - // Prefix used to distinguish trace spans on the same timeline - std::string prefix; -}; - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/libkineto.h b/plugins/tensorboard-plugins/libkineto/include/libkineto.h deleted file mode 100644 index 87c3d64f638dad9d1c2d24c013135db60d477642..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/libkineto.h +++ /dev/null @@ -1,138 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -// Mediator for initialization and profiler control - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ActivityProfilerInterface.h" -#include "ActivityType.h" -#include "ClientInterface.h" -#include "GenericTraceActivity.h" -#include "TraceSpan.h" -#include "IActivityProfiler.h" -#include "ActivityTraceInterface.h" - -#include "ThreadUtil.h" - -extern "C" { - void suppressLibkinetoLogMessages(); - int InitializeInjection(void); - bool libkineto_init(bool cpuOnly, bool logOnError); -} - -namespace libkineto { - -class Config; -class ConfigLoader; - -struct CpuTraceBuffer { - TraceSpan span{0, 0, "none"}; - int gpuOpCount; - std::deque activities; -}; - -using ChildActivityProfilerFactory = - std::function()>; - -class LibkinetoApi { - public: - - explicit LibkinetoApi(ConfigLoader& configLoader) - : configLoader_(configLoader) { - } - - // Called by client that supports tracing API. - // libkineto can still function without this. - void registerClient(ClientInterface* client); - - // Called by libkineto on init - void registerProfiler(std::unique_ptr profiler) { - activityProfiler_ = std::move(profiler); - initClientIfRegistered(); - } - - ActivityProfilerInterface& activityProfiler() { - return *activityProfiler_; - } - - ClientInterface* client() { - return client_; - } - - void initProfilerIfRegistered() { - static std::once_flag once; - if (activityProfiler_) { - std::call_once(once, [this] { - if (!activityProfiler_->isInitialized()) { - activityProfiler_->init(); - initChildActivityProfilers(); - } - }); - } - } - - bool isProfilerInitialized() const { - return activityProfiler_ && activityProfiler_->isInitialized(); - } - - bool isProfilerRegistered() const { - return activityProfiler_ != nullptr; - } - - void suppressLogMessages() { - suppressLibkinetoLogMessages(); - } - - // Provides access to profier configuration manaegement - ConfigLoader& configLoader() { - return configLoader_; - } - - void registerProfilerFactory( - ChildActivityProfilerFactory factory) { - if (isProfilerInitialized()) { - activityProfiler_->addChildActivityProfiler(factory()); - } else { - childProfilerFactories_.push_back(factory); - } - } - - private: - - void initChildActivityProfilers() { - if (!isProfilerInitialized()) { - return; - } - for (const auto& factory : childProfilerFactories_) { - activityProfiler_->addChildActivityProfiler(factory()); - } - childProfilerFactories_.clear(); - } - - // Client is initialized once both it and libkineto has registered - void initClientIfRegistered(); - - ConfigLoader& configLoader_; - std::unique_ptr activityProfiler_{}; - ClientInterface* client_{}; - int32_t clientRegisterThread_{0}; - - bool isLoaded_{false}; - std::vector childProfilerFactories_; -}; - -// Singleton -LibkinetoApi& api(); - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/include/time_since_epoch.h b/plugins/tensorboard-plugins/libkineto/include/time_since_epoch.h deleted file mode 100644 index caa6b4d92760d384eca2b1383a679fe7435c53b3..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/include/time_since_epoch.h +++ /dev/null @@ -1,16 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include - -namespace libkineto { - -inline int64_t timeSinceEpoch( - const std::chrono::time_point& t) { - return std::chrono::duration_cast( - t.time_since_epoch()) - .count(); -} - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/libkineto_defs.bzl b/plugins/tensorboard-plugins/libkineto/libkineto_defs.bzl deleted file mode 100644 index 330c54a22dfcedf895f0eba4077713a7c4cd8072..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/libkineto_defs.bzl +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -def get_libkineto_api_srcs(): - return [ - "src/ThreadUtil.cpp", - "src/libkineto_api.cpp", - ] - -def get_libkineto_cupti_srcs(with_api = True): - return [ - "src/CudaDeviceProperties.cpp", - "src/CuptiActivityApi.cpp", - "src/CuptiActivityPlatform.cpp", - "src/CuptiCallbackApi.cpp", - "src/CuptiEventApi.cpp", - "src/CuptiMetricApi.cpp", - "src/CuptiRangeProfilerApi.cpp", - "src/Demangle.cpp", - "src/EventProfiler.cpp", - "src/EventProfilerController.cpp", - "src/WeakSymbols.cpp", - "src/cupti_strings.cpp", - ] + (get_libkineto_cpu_only_srcs(with_api)) - -def get_libkineto_roctracer_srcs(with_api = True): - return [ - "src/RoctracerActivityApi.cpp", - ] + (get_libkineto_cpu_only_srcs(with_api)) - -def get_libkineto_cpu_only_srcs(with_api = True): - return [ - "src/AbstractConfig.cpp", - "src/CuptiActivityProfiler.cpp", - "src/ActivityProfilerController.cpp", - "src/ActivityProfilerProxy.cpp", - "src/ActivityType.cpp", - "src/Config.cpp", - "src/ConfigLoader.cpp", - "src/CuptiActivityApi.cpp", - "src/Demangle.cpp", - "src/GenericTraceActivity.cpp", - "src/ILoggerObserver.cpp", - "src/Logger.cpp", - "src/init.cpp", - "src/output_csv.cpp", - "src/output_json.cpp", - ] + (get_libkineto_api_srcs() if with_api else []) - -def get_libkineto_public_headers(): - return [ - "include/AbstractConfig.h", - "include/ActivityProfilerInterface.h", - "include/ActivityType.h", - "include/Config.h", - "include/ClientInterface.h", - "include/GenericTraceActivity.h", - "include/GenericTraceActivity.h", - "include/IActivityProfiler.h", - "include/ILoggerObserver.h", - "include/ITraceActivity.h", - "include/TraceSpan.h", - "include/ThreadUtil.h", - "include/libkineto.h", - "include/time_since_epoch.h", - ] - -# kineto code should be updated to not have to -# suppress these warnings. -KINETO_COMPILER_FLAGS = [ - "-fexceptions", - "-Wno-deprecated-declarations", - "-Wno-unused-function", - "-Wno-unused-private-field", -] diff --git a/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cpp b/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cpp deleted file mode 100644 index 780047912ed09996d3952901267d46aab99cf78c..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include -#include -#include - -#include -#include - -#include "kineto/libkineto/sample_programs/kineto_playground.cuh" - -using namespace kineto; - -static const std::string kFileName = "/tmp/kineto_playground_trace.json"; - -int main() { - warmup(); - - // Kineto config - - // Empty types set defaults to all types - std::set types; - - auto& profiler = libkineto::api().activityProfiler(); - libkineto::api().initProfilerIfRegistered(); - profiler.prepareTrace(types); - - // Good to warm up after prepareTrace to get cupti initialization to settle - warmup(); - profiler.startTrace(); - playground(); - - auto trace = profiler.stopTrace(); - LOG(INFO) << "Stopped and processed trace. Got " << trace->activities()->size() << " activities."; - trace->save(kFileName); - return 0; -} - diff --git a/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cu b/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cu deleted file mode 100644 index 54c6f82ff4be2e468c0e868b49b3a9130de97490..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cu +++ /dev/null @@ -1,60 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include - -#include "kineto_playground.cuh" - - -namespace kineto { - -void warmup(void) { - // Inititalizing CUDA can take a while which we normally do not want to see in Kineto traces. - // This is done in various ways that take Kineto as dependency. This is our way of doing warmup - // for kineto_playground - size_t bytes = 1000; - float* mem = NULL; - auto error = cudaMalloc(&mem, bytes); - if (error != cudaSuccess) { - printf("cudaMalloc failed during kineto_playground warmup. error code: %d", error); - return; - } - - cudaFree(mem); -} - -void basicMemcpyMemset(void) { - size_t size = (1 << 8) * sizeof(float); - float *hostMemSrc, *deviceMem, *hostMemDst; - cudaError_t err; - - hostMemSrc = (float*)malloc(size); - hostMemDst = (float*)malloc(size); - err = cudaMalloc(&deviceMem, size); - if (err != cudaSuccess) { - printf("cudaMalloc failed during %s", __func__); - return; - } - - memset(hostMemSrc, 1, size); - cudaMemcpy(deviceMem, hostMemSrc, size, cudaMemcpyHostToDevice); - if (err != cudaSuccess) { - printf("cudaMemcpy failed during %s", __func__); - return; - } - - cudaMemcpy(hostMemDst, deviceMem, size, cudaMemcpyDeviceToHost); - if (err != cudaSuccess) { - printf("cudaMemcpy failed during %s", __func__); - return; - } - - free(hostMemSrc); - free(hostMemDst); - cudaFree(deviceMem); -} - -void playground(void) { - // Add your experimental CUDA implementation here. -} - -} diff --git a/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cuh b/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cuh deleted file mode 100644 index 54e1ee59ada9ae88370b38146567ed87be2b914b..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cuh +++ /dev/null @@ -1,18 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include - -namespace kineto { - -// Warms up CUDA before the tracing starts -void warmup(void); - -// Basic usage of cudaMemcpy and cudaMemset -void basicMemcpyMemset(void); - -// Your experimental code goes in here! -void playground(void); - -} diff --git a/plugins/tensorboard-plugins/libkineto/src/AbstractConfig.cpp b/plugins/tensorboard-plugins/libkineto/src/AbstractConfig.cpp deleted file mode 100644 index d60ab43c9a3e198167beb7987d619b0bb8e9ed13..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/AbstractConfig.cpp +++ /dev/null @@ -1,188 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "AbstractConfig.h" - -#include -#include -#include - -#include "Logger.h" - -using namespace std::chrono; - -using std::string; -using std::vector; - -namespace KINETO_NAMESPACE { - -constexpr char kWhitespace[] = "\t\n "; - -static bool isWhitespace(string& s) { - return s.find_first_not_of(kWhitespace) == string::npos; -} - -// Remove whitespace from both end of string -static inline string trim(string& s) { - if (s.empty()) { - return s; - } else if (isWhitespace(s)) { - return ""; - } - auto start = s.find_first_not_of(kWhitespace); - auto end = s.find_last_not_of(kWhitespace); - return s.substr(start, end - start + 1); -} - -// Helper function for split. -// Return the index of char d in string s. -// If not found, returns the length of the string. -static int find(const char* s, char delim) { - int i; - for (i = 0; s[i]; i++) { - if (s[i] == delim) { - break; - } - } - return i; -} - -// Split a string by delimiter -static vector split(const string& s, char delim) { - vector res; - const char* cs = s.c_str(); - for (int i = find(cs, delim); cs[i]; cs += i + 1, i = find(cs, delim)) { - res.emplace_back(cs, i); - } - res.emplace_back(cs); - return res; -} - -// Remove a trailing comment. -static inline string stripComment(const string& s) { - std::size_t pos = s.find("#"); - return s.substr(0, pos); -} - -string AbstractConfig::toLower(string& s) const { - string res = s; - for (int i = 0; i < res.size(); i++) { - if (res[i] >= 'A' && res[i] <= 'Z') { - res[i] += ('a' - 'A'); - } - } - return res; -} - -bool AbstractConfig::endsWith(const string& s, const string& suffix) const { - if (suffix.size() > s.size()) { - return false; - } - return s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0; -} - -vector AbstractConfig::splitAndTrim(const string& s, char delim) const { - auto res = split(s, delim); - for (string& x : res) { - x = trim(x); - } - return res; -} - -int64_t AbstractConfig::toIntRange(const string& val, int64_t min, int64_t max) - const { - char* invalid; - int64_t res = strtoll(val.c_str(), &invalid, 10); - if (val.empty() || *invalid) { - throw std::invalid_argument(fmt::format("Invalid integer: {}", val)); - } else if (res < min || res > max) { - throw std::invalid_argument(fmt::format( - "Invalid argument: {} - expected range [{}, {}]", res, min, max)); - } - return res; -} - -int32_t AbstractConfig::toInt32(const string& val) const { - return toIntRange(val, 0, ~0u / 2); -} - -int64_t AbstractConfig::toInt64(const string& val) const { - return toIntRange(val, 0, ~0ul / 2); -} - -bool AbstractConfig::toBool(string& val) const { - const std::array bool_vals{ - "n", "y", "no", "yes", "f", "t", "false", "true"}; - const string lower_val = toLower(val); - for (int i = 0; i < bool_vals.size(); i++) { - if (lower_val == bool_vals[i]) { - return i % 2; - } - } - throw std::invalid_argument(fmt::format("Invalid bool argument: {}", val)); - return false; -} - -bool AbstractConfig::parse(const string& conf) { - std::istringstream iss(conf); - string line; - - timestamp_ = system_clock::now(); - - // Read the string stream 1 line at a time to parse. - while (std::getline(iss, line)) { - line = stripComment(line); - if (isWhitespace(line)) { - continue; - } - vector key_val = splitAndTrim(line, '='); - if (key_val.size() != 2) { - LOG(ERROR) << "Invalid config line: " << line; - return false; - } else { - bool handled = false; - try { - handled = handleOption(key_val[0], key_val[1]); - if (!handled) { - for (auto& feature_cfg : featureConfigs_) { - if (feature_cfg.second->handleOption(key_val[0], key_val[1])) { - handled = true; - break; - } - } - } - } catch (const std::exception& e) { - LOG(ERROR) << "Failed to parse config line: " << line; - LOG(ERROR) << e.what(); - return false; - } - if (!handled) { - // This might be due to using a newer config option on an - // older binary where it is not supported. In this case, - // print a warning message - but it is expected to work! - LOG(WARNING) << "Unrecognized config line: " << line; - } - } - } - - validate(timestamp_); - - // Store original text, used to detect updates - source_ = conf; - timestamp_ = system_clock::now(); - return true; -} - -bool AbstractConfig::handleOption( - const std::string& /* unused */, - std::string& /* unused */) { - LOG(ERROR) << "handleOption unimplemented"; - return false; -} - -void AbstractConfig::printActivityProfilerConfig(std::ostream& s) const { - for (const auto& feature_cfg : featureConfigs_) { - feature_cfg.second->printActivityProfilerConfig(s); - } -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityBuffers.h b/plugins/tensorboard-plugins/libkineto/src/ActivityBuffers.h deleted file mode 100644 index 157af879379a5f5fc5e274f22604987a97f17af4..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ActivityBuffers.h +++ /dev/null @@ -1,29 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - - -#include -#include - -#include "libkineto.h" -#include "CuptiActivityBuffer.h" - -namespace KINETO_NAMESPACE { - -struct ActivityBuffers { - std::list> cpu; - std::unique_ptr gpu; - - // Add a wrapper object to the underlying struct stored in the buffer - template - const ITraceActivity& addActivityWrapper(const T& act) { - wrappers_.push_back(std::make_unique(act)); - return *wrappers_.back().get(); - } - - private: - std::vector> wrappers_; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityLoggerFactory.h b/plugins/tensorboard-plugins/libkineto/src/ActivityLoggerFactory.h deleted file mode 100644 index 0d1bf642cd68051e487004d33e19c5eb181e1c41..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ActivityLoggerFactory.h +++ /dev/null @@ -1,60 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace KINETO_NAMESPACE { - -class ActivityLogger; - -class ActivityLoggerFactory { - - public: - using FactoryFunc = - std::function(const std::string& url)>; - - // Add logger factory for a protocol prefix - void addProtocol(const std::string& protocol, FactoryFunc f) { - factories_[tolower(protocol)] = f; - } - - // Create a logger, invoking the factory for the protocol specified in url - std::unique_ptr makeLogger(const std::string& url) const { - std::string protocol = extractProtocol(url); - auto it = factories_.find(tolower(protocol)); - if (it != factories_.end()) { - return it->second(stripProtocol(url)); - } - throw std::invalid_argument(fmt::format( - "No logger registered for the {} protocol prefix", - protocol)); - return nullptr; - } - - private: - static std::string tolower(std::string s) { - std::transform(s.begin(), s.end(), s.begin(), - [](unsigned char c) { return std::tolower(c); } - ); - return s; - } - - static std::string extractProtocol(std::string url) { - return url.substr(0, url.find("://")); - } - - static std::string stripProtocol(std::string url) { - size_t pos = url.find("://"); - return pos == url.npos ? url : url.substr(pos + 3); - } - - std::map factories_; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.cpp b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.cpp deleted file mode 100644 index c85d41ed73ff059bcd7ee69c36a0bcc6c3d5c4ca..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.cpp +++ /dev/null @@ -1,246 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "ActivityProfilerController.h" - -#include -#include - -#include "ActivityLoggerFactory.h" -#include "ActivityTrace.h" -#include "CuptiActivityApi.h" -#ifdef HAS_ROCTRACER -#include "RoctracerActivityApi.h" -#endif -#include "ThreadUtil.h" -#include "output_json.h" -#include "output_membuf.h" - -#include "Logger.h" - -using namespace std::chrono; - -namespace KINETO_NAMESPACE { - -constexpr milliseconds kProfilerIntervalMsecs(1000); - -ActivityProfilerController::ActivityProfilerController( - ConfigLoader& configLoader, bool cpuOnly) - : configLoader_(configLoader) { -#ifdef HAS_ROCTRACER - profiler_ = std::make_unique( - RoctracerActivityApi::singleton(), cpuOnly); -#else - profiler_ = std::make_unique( - CuptiActivityApi::singleton(), cpuOnly); -#endif - configLoader_.addHandler(ConfigLoader::ConfigKind::ActivityProfiler, this); -} - -ActivityProfilerController::~ActivityProfilerController() { - configLoader_.removeHandler( - ConfigLoader::ConfigKind::ActivityProfiler, this); - if (profilerThread_) { - // signaling termination of the profiler loop - stopRunloop_ = true; - profilerThread_->join(); - delete profilerThread_; - profilerThread_ = nullptr; - } -} - -static ActivityLoggerFactory initLoggerFactory() { - ActivityLoggerFactory factory; - factory.addProtocol("file", [](const std::string& url) { - return std::unique_ptr(new ChromeTraceLogger(url)); - }); - return factory; -} - -static ActivityLoggerFactory& loggerFactory() { - static ActivityLoggerFactory factory = initLoggerFactory(); - return factory; -} - -void ActivityProfilerController::addLoggerFactory( - const std::string& protocol, ActivityLoggerFactory::FactoryFunc factory) { - loggerFactory().addProtocol(protocol, factory); -} - -static std::unique_ptr makeLogger(const Config& config) { - if (config.activitiesLogToMemory()) { - return std::make_unique(config); - } - return loggerFactory().makeLogger(config.activitiesLogUrl()); -} - -bool ActivityProfilerController::canAcceptConfig() { - return !profiler_->isActive(); -} - -void ActivityProfilerController::acceptConfig(const Config& config) { - VLOG(1) << "acceptConfig"; - if (config.activityProfilerEnabled()) { - scheduleTrace(config); - } -} - -void ActivityProfilerController::profilerLoop() { - setThreadName("Kineto Activity Profiler"); - VLOG(0) << "Entering activity profiler loop"; - - auto now = system_clock::now(); - auto next_wakeup_time = now + kProfilerIntervalMsecs; - - while (!stopRunloop_) { - now = system_clock::now(); - - while (now < next_wakeup_time) { - /* sleep override */ - std::this_thread::sleep_for(next_wakeup_time - now); - now = system_clock::now(); - } - - if (!profiler_->isActive()) { - std::lock_guard lock(asyncConfigLock_); - if (asyncRequestConfig_ - && !asyncRequestConfig_->hasProfileStartIteration()) { - // Note on now + kProfilerIntervalMsecs - // Profiler interval does not align perfectly upto startTime - warmup. Waiting until the next tick - // won't allow sufficient time for the profiler to warm up. So check if we are very close to the warmup time and trigger warmup - if (now + kProfilerIntervalMsecs - >= (asyncRequestConfig_->requestTimestamp() - asyncRequestConfig_->activitiesWarmupDuration())) { - LOG(INFO) << "Received on-demand activity trace request by " - << " profile timestamp = " - << asyncRequestConfig_-> - requestTimestamp().time_since_epoch().count(); - activateConfig(now); - } - } - } - - while (next_wakeup_time < now) { - next_wakeup_time += kProfilerIntervalMsecs; - } - - if (profiler_->isActive()) { - next_wakeup_time = profiler_->performRunLoopStep(now, next_wakeup_time); - VLOG(1) << "Profiler loop: " - << duration_cast(system_clock::now() - now).count() - << "ms"; - } - } - - VLOG(0) << "Exited activity profiling loop"; -} - -void ActivityProfilerController::step() { - int64_t currentIter = ++iterationCount_; - VLOG(0) << "Step called , iteration = " << currentIter; - - // optimization to not take the lock unless necessary - if (asyncRequestConfig_ && !profiler_->isActive()) { - std::lock_guard lock(asyncConfigLock_); - auto startIter = asyncRequestConfig_->startIterationIncludingWarmup(); - - if (asyncRequestConfig_->hasProfileStartIteration() - && currentIter >= startIter) { - LOG(INFO) << "Received on-demand activity trace request by profile" - << " start iteration = " - << asyncRequestConfig_->profileStartIteration() - << " current iteration = " << currentIter; - - if (currentIter > startIter) { - // adjust the start iteration if it is in the past - auto newProfileStart = currentIter + - asyncRequestConfig_->activitiesWarmupIterations(); - LOG(INFO) << "Start iteration updated to " << newProfileStart; - asyncRequestConfig_->setProfileStartIteration(newProfileStart); - } - activateConfig(system_clock::now()); - } - } - - if (profiler_->isActive()) { - auto now = system_clock::now(); - auto next_wakeup_time = now + kProfilerIntervalMsecs; - profiler_->performRunLoopStep(now, next_wakeup_time, currentIter); - } -} - -void ActivityProfilerController::activateConfig( - std::chrono::time_point now) { - logger_ = makeLogger(*asyncRequestConfig_); - profiler_->setLogger(logger_.get()); - profiler_->configure(*asyncRequestConfig_, now); - asyncRequestConfig_ = nullptr; -} - -void ActivityProfilerController::scheduleTrace(const Config& config) { - VLOG(1) << "scheduleTrace"; - if (profiler_->isActive()) { - LOG(ERROR) << "Ignored request - profiler busy"; - return; - } - int64_t currentIter = iterationCount_; - if (config.hasProfileStartIteration() && currentIter < 0) { - LOG(ERROR) << "Ignored profile iteration count based request as " - << "application is not updating iteration count"; - return; - } - std::lock_guard lock(asyncConfigLock_); - asyncRequestConfig_ = config.clone(); - - auto startIter = asyncRequestConfig_->startIterationIncludingWarmup(); - - if (asyncRequestConfig_->hasProfileStartIteration() - && (currentIter > startIter) - && asyncRequestConfig_->profileStartIterationRoundUp() > 0) { - auto newProfileStart - = currentIter + asyncRequestConfig_->activitiesWarmupIterations(); - // round up to nearest multiple - auto divisor = asyncRequestConfig_->profileStartIterationRoundUp(); - auto rem = newProfileStart % divisor; - newProfileStart += ((rem == 0) ? 0 : divisor - rem); - LOG(INFO) << "Rounding up profiler start iteration to : " << newProfileStart; - asyncRequestConfig_->setProfileStartIteration(newProfileStart); - } - - // start a profilerLoop() thread to handle request - if (!profilerThread_) { - profilerThread_ = - new std::thread(&ActivityProfilerController::profilerLoop, this); - } -} - -void ActivityProfilerController::prepareTrace(const Config& config) { - // Requests from ActivityProfilerApi have higher priority than - // requests from other sources (signal, daemon). - // Cancel any ongoing request and refuse new ones. - auto now = system_clock::now(); - if (profiler_->isActive()) { - LOG(WARNING) << "Cancelling current trace request in order to start " - << "higher priority synchronous request"; - if (libkineto::api().client()) { - libkineto::api().client()->stop(); - } - profiler_->stopTrace(now); - profiler_->reset(); - } - - profiler_->configure(config, now); -} - -std::unique_ptr ActivityProfilerController::stopTrace() { - profiler_->stopTrace(std::chrono::system_clock::now()); - auto logger = std::make_unique(profiler_->config()); - profiler_->processTrace(*logger); - profiler_->reset(); - return std::make_unique(std::move(logger), loggerFactory()); -} - -void ActivityProfilerController::addMetadata( - const std::string& key, const std::string& value) { - profiler_->addMetadata(key, value); -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.h b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.h deleted file mode 100644 index 415f107cbed6aab4777c65e9e51d65686002e762..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.h +++ /dev/null @@ -1,84 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include - -#include "ActivityLoggerFactory.h" -#include "CuptiActivityProfiler.h" -#include "ActivityProfilerInterface.h" -#include "ActivityTraceInterface.h" -#include "ConfigLoader.h" -#include "CuptiActivityApi.h" - -namespace KINETO_NAMESPACE { - -class Config; - -class ActivityProfilerController : public ConfigLoader::ConfigHandler { - public: - explicit ActivityProfilerController(ConfigLoader& configLoader, bool cpuOnly); - ActivityProfilerController(const ActivityProfilerController&) = delete; - ActivityProfilerController& operator=(const ActivityProfilerController&) = - delete; - - ~ActivityProfilerController(); - - static void addLoggerFactory( - const std::string& protocol, - ActivityLoggerFactory::FactoryFunc factory); - - bool canAcceptConfig() override; - void acceptConfig(const Config& config) override; - - void scheduleTrace(const Config& config); - - void prepareTrace(const Config& config); - - void startTrace() { - profiler_->startTrace(std::chrono::system_clock::now()); - } - - void step(); - - std::unique_ptr stopTrace(); - - bool isActive() { - return profiler_->isActive(); - } - - void transferCpuTrace( - std::unique_ptr cpuTrace) { - return profiler_->transferCpuTrace(std::move(cpuTrace)); - } - - void recordThreadInfo() { - profiler_->recordThreadInfo(); - } - - void addChildActivityProfiler( - std::unique_ptr profiler) { - profiler_->addChildActivityProfiler(std::move(profiler)); - } - - void addMetadata(const std::string& key, const std::string& value); - - private: - void profilerLoop(); - void activateConfig(std::chrono::time_point now); - - std::unique_ptr asyncRequestConfig_; - std::mutex asyncConfigLock_; - std::unique_ptr profiler_; - std::unique_ptr logger_; - std::thread* profilerThread_{nullptr}; - std::atomic_bool stopRunloop_{false}; - std::atomic iterationCount_{-1}; - ConfigLoader& configLoader_; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.cpp b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.cpp deleted file mode 100644 index b2d36b7b3abf9c3e0aed838a10e4054a5d292139..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.cpp +++ /dev/null @@ -1,119 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "ActivityProfilerProxy.h" - -#include "ActivityProfilerController.h" -#include "Config.h" -#include "CuptiActivityApi.h" -#include "Logger.h" -#include - -namespace KINETO_NAMESPACE { - -ActivityProfilerProxy::ActivityProfilerProxy( - bool cpuOnly, ConfigLoader& configLoader) - : cpuOnly_(cpuOnly), configLoader_(configLoader) { -} - -ActivityProfilerProxy::~ActivityProfilerProxy() { - delete controller_; -}; - -void ActivityProfilerProxy::init() { - if (!controller_) { - controller_ = new ActivityProfilerController(configLoader_, cpuOnly_); - } -} - -void ActivityProfilerProxy::scheduleTrace(const std::string& configStr) { - Config config; - config.parse(configStr); - controller_->scheduleTrace(config); -} - -void ActivityProfilerProxy::scheduleTrace(const Config& config) { - controller_->scheduleTrace(config); -} - -void ActivityProfilerProxy::prepareTrace( - const std::set& activityTypes, - const std::string& configStr) { - Config config; - bool validate_required = true; - - // allow user provided config to override default options - if (!configStr.empty()) { - if (!config.parse(configStr)) { - LOG(WARNING) << "Failed to parse config : " << configStr; - } - // parse also runs validate - validate_required = false; - } - - config.setClientDefaults(); - config.setSelectedActivityTypes(activityTypes); - - if (validate_required) { - config.validate(std::chrono::system_clock::now()); - } - - controller_->prepareTrace(config); -} - -void ActivityProfilerProxy::startTrace() { - controller_->startTrace(); -} - -std::unique_ptr -ActivityProfilerProxy::stopTrace() { - return controller_->stopTrace(); -} - -void ActivityProfilerProxy::step() { - controller_->step(); -} - -bool ActivityProfilerProxy::isActive() { - return controller_->isActive(); -} - -void ActivityProfilerProxy::pushCorrelationId(uint64_t id) { - CuptiActivityApi::pushCorrelationID(id, - CuptiActivityApi::CorrelationFlowType::Default); -} - -void ActivityProfilerProxy::popCorrelationId() { - CuptiActivityApi::popCorrelationID( - CuptiActivityApi::CorrelationFlowType::Default); -} - -void ActivityProfilerProxy::pushUserCorrelationId(uint64_t id) { - CuptiActivityApi::pushCorrelationID(id, - CuptiActivityApi::CorrelationFlowType::User); -} - -void ActivityProfilerProxy::popUserCorrelationId() { - CuptiActivityApi::popCorrelationID( - CuptiActivityApi::CorrelationFlowType::User); -} - -void ActivityProfilerProxy::transferCpuTrace( - std::unique_ptr traceBuffer) { - controller_->transferCpuTrace(std::move(traceBuffer)); -} - -void ActivityProfilerProxy::addMetadata( - const std::string& key, const std::string& value) { - controller_->addMetadata(key, value); -} - -void ActivityProfilerProxy::recordThreadInfo() { - controller_->recordThreadInfo(); -} - -void ActivityProfilerProxy::addChildActivityProfiler( - std::unique_ptr profiler) { - controller_->addChildActivityProfiler(std::move(profiler)); -} - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.h b/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.h deleted file mode 100644 index b5cf84b2f1ddb005060fea0927c99fc63d144d99..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.h +++ /dev/null @@ -1,73 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include "ActivityProfilerInterface.h" - -#include -#include -#include - -#include "ActivityType.h" -#include "ITraceActivity.h" - -namespace libkineto { - // previous declaration is struct so this one must be too. - struct CpuTraceBuffer; -} - -namespace KINETO_NAMESPACE { - -using namespace libkineto; - -class ActivityProfilerController; -class Config; -class ConfigLoader; - -class ActivityProfilerProxy : public ActivityProfilerInterface { - - public: - ActivityProfilerProxy(bool cpuOnly, ConfigLoader& configLoader); - ~ActivityProfilerProxy() override; - - void init() override; - bool isInitialized() override { - return controller_ != nullptr; - } - - bool isActive() override; - - void recordThreadInfo() override; - - void scheduleTrace(const std::string& configStr) override; - void scheduleTrace(const Config& config); - - void prepareTrace( - const std::set& activityTypes, - const std::string& configStr = "") override; - - void startTrace() override; - void step() override; - std::unique_ptr stopTrace() override; - - void pushCorrelationId(uint64_t id) override; - void popCorrelationId() override; - - void pushUserCorrelationId(uint64_t id) override; - void popUserCorrelationId() override; - - void transferCpuTrace( - std::unique_ptr traceBuffer) override; - - void addMetadata(const std::string& key, const std::string& value) override; - - virtual void addChildActivityProfiler( - std::unique_ptr profiler) override; - - private: - bool cpuOnly_{true}; - ConfigLoader& configLoader_; - ActivityProfilerController* controller_{nullptr}; -}; - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityTrace.h b/plugins/tensorboard-plugins/libkineto/src/ActivityTrace.h deleted file mode 100644 index 0be76af08e47c16ebee2ac1d1ad01c4425ff17a5..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ActivityTrace.h +++ /dev/null @@ -1,45 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include - -#include "ActivityLoggerFactory.h" -#include "ActivityTraceInterface.h" -#include "output_json.h" -#include "output_membuf.h" - -namespace libkineto { - -class ActivityTrace : public ActivityTraceInterface { - public: - ActivityTrace( - std::unique_ptr tmpLogger, - const ActivityLoggerFactory& factory) - : memLogger_(std::move(tmpLogger)), - loggerFactory_(factory) { - } - - const std::vector* activities() override { - return memLogger_->traceActivities(); - }; - - void save(const std::string& url) override { - std::string prefix; - // if no protocol is specified, default to file - if (url.find("://") == url.npos) { - prefix = "file://"; - } - memLogger_->log(*loggerFactory_.makeLogger(prefix + url)); - }; - - private: - // Activities are logged into a buffer - std::unique_ptr memLogger_; - - // Alternative logger used by save() if protocol prefix is specified - const ActivityLoggerFactory& loggerFactory_; -}; - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/ActivityType.cpp b/plugins/tensorboard-plugins/libkineto/src/ActivityType.cpp deleted file mode 100644 index 18856b72370abdb6d9cf4309b32be4cae10805de..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ActivityType.cpp +++ /dev/null @@ -1,58 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "ActivityType.h" - -#include - -namespace libkineto { - -struct ActivityTypeName { - const char* name; - ActivityType type; -}; - -static constexpr std::array map{{ - {"cpu_op", ActivityType::CPU_OP}, - {"user_annotation", ActivityType::USER_ANNOTATION}, - {"gpu_user_Annotation", ActivityType::GPU_USER_ANNOTATION}, - {"gpu_memcpy", ActivityType::GPU_MEMCPY}, - {"gpu_memset", ActivityType::GPU_MEMSET}, - {"kernel", ActivityType::CONCURRENT_KERNEL}, - {"external_correlation", ActivityType::EXTERNAL_CORRELATION}, - {"cuda_runtime", ActivityType::CUDA_RUNTIME}, - {"cuda_profiler_range", ActivityType::CUDA_PROFILER_RANGE}, - {"glow_runtime", ActivityType::GLOW_RUNTIME}, - {"cpu_instant_event", ActivityType::CPU_INSTANT_EVENT}, - {"python_function", ActivityType::PYTHON_FUNCTION}, - {"overhead", ActivityType::OVERHEAD}, - {"ENUM_COUNT", ActivityType::ENUM_COUNT} -}}; - -static constexpr bool matchingOrder(int idx = 0) { - return map[idx].type == ActivityType::ENUM_COUNT || - ((idx == (int) map[idx].type) && matchingOrder(idx + 1)); -} -static_assert(matchingOrder(), "ActivityTypeName map is out of order"); - -const char* toString(ActivityType t) { - return map[(int)t].name; -} - -ActivityType toActivityType(const std::string& str) { - for (int i = 0; i < activityTypeCount; i++) { - if (str == map[i].name) { - return map[i].type; - } - } - throw std::invalid_argument(fmt::format("Invalid activity type: {}", str)); -} - -const std::array activityTypes() { - std::array res; - for (int i = 0; i < activityTypeCount; i++) { - res[i] = map[i].type; - } - return res; -} - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/Config.cpp b/plugins/tensorboard-plugins/libkineto/src/Config.cpp deleted file mode 100644 index 95538840f378e83b2b44161823042c620b34fe93..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/Config.cpp +++ /dev/null @@ -1,473 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "Config.h" - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "Logger.h" -#include "ThreadUtil.h" - -using namespace std::chrono; - -using std::string; -using std::vector; - -namespace KINETO_NAMESPACE { - -constexpr milliseconds kDefaultSamplePeriodMsecs(1000); -constexpr milliseconds kDefaultMultiplexPeriodMsecs(1000); -constexpr milliseconds kDefaultActivitiesProfileDurationMSecs(500); -constexpr int kDefaultActivitiesMaxGpuBufferSize(128 * 1024 * 1024); -constexpr seconds kDefaultActivitiesWarmupDurationSecs(5); -constexpr seconds kDefaultBufferUntilWarmup(10); -constexpr seconds kDefaultReportPeriodSecs(1); -constexpr int kDefaultSamplesPerReport(1); -constexpr int kDefaultMaxEventProfilersPerGpu(1); -constexpr int kDefaultEventProfilerHearbeatMonitorPeriod(0); -constexpr seconds kMaxRequestAge(10); - -// Event Profiler -constexpr char kEventsKey[] = "EVENTS"; -constexpr char kMetricsKey[] = "METRICS"; -constexpr char kSamplePeriodKey[] = "SAMPLE_PERIOD_MSECS"; -constexpr char kMultiplexPeriodKey[] = "MULTIPLEX_PERIOD_MSECS"; -constexpr char kReportPeriodKey[] = "REPORT_PERIOD_SECS"; -constexpr char kSamplesPerReportKey[] = "SAMPLES_PER_REPORT"; -constexpr char kEventsLogFileKey[] = "EVENTS_LOG_FILE"; -constexpr char kEventsEnabledDevicesKey[] = "EVENTS_ENABLED_DEVICES"; -constexpr char kOnDemandDurationKey[] = "EVENTS_DURATION_SECS"; -constexpr char kMaxEventProfilersPerGpuKey[] = "MAX_EVENT_PROFILERS_PER_GPU"; -constexpr char kHeartbeatMonitorPeriodKey[] = - "EVENTS_HEARTBEAT_MONITOR_PERIOD_SECS"; - -// Activity Profiler -constexpr char kActivitiesEnabledKey[] = "ACTIVITIES_ENABLED"; -constexpr char kActivityTypesKey[] = "ACTIVITY_TYPES"; -constexpr char kActivitiesLogFileKey[] = "ACTIVITIES_LOG_FILE"; -constexpr char kActivitiesDurationKey[] = "ACTIVITIES_DURATION_SECS"; -constexpr char kActivitiesDurationMsecsKey[] = "ACTIVITIES_DURATION_MSECS"; -constexpr char kActivitiesWarmupDurationSecsKey[] = "ACTIVITIES_WARMUP_PERIOD_SECS"; -constexpr char kActivitiesMaxGpuBufferSizeKey[] = - "ACTIVITIES_MAX_GPU_BUFFER_SIZE_MB"; - -// Client Interface -constexpr char kClientInterfaceEnableOpInputsCollection[] = "CLIENT_INTERFACE_ENABLE_OP_INPUTS_COLLECTION"; - -constexpr char kActivitiesWarmupIterationsKey[] = "ACTIVITIES_WARMUP_ITERATIONS"; -constexpr char kActivitiesIterationsKey[] = "ACTIVITIES_ITERATIONS"; -// Common - -// Client-side timestamp used for synchronized start across hosts for -// distributed workloads. -// Specified in milliseconds Unix time (milliseconds since epoch). -// To use, compute a future timestamp as follows: -// * C++: + duration_cast( -// system_clock::now().time_since_epoch()).count() -// * Python: + int(time.time() * 1000) -// * Bash: $(( + $(date +%s%3N))) -// If used for a tracing request, timestamp must be far enough in the future -// to accommodate ACTIVITIES_WARMUP_PERIOD_SECS as well as any delays in -// propagating the request to the profiler. -// If the request can not be honored, it is up to the profilers to report -// an error somehow - no checks are done at config parse time. -// Note PROFILE_START_ITERATION has higher precedence -constexpr char kProfileStartTimeKey[] = "PROFILE_START_TIME"; -// DEPRECATED - USE PROFILE_START_TIME instead -constexpr char kRequestTimestampKey[] = "REQUEST_TIMESTAMP"; - -// Alternatively if the application supports reporting iterations -// start the profile at specific iteration. If the iteration count -// is >= this value the profile is started immediately. -// A value >= 0 is valid for this config option to take effect. -// Note PROFILE_START_ITERATION will take precedence over PROFILE_START_TIME. -constexpr char kProfileStartIterationKey[] = "PROFILE_START_ITERATION"; - -// Users can also start the profile on an integer multiple of the config -// value PROFILE_START_ITERATION_ROUNDUP. This knob behaves similar to -// PROFILE_START_ITERATION but instead of saying : "start collection trace on -// iteration 500", one can configure it to "start collecting trace on the next -// 100th iteration". -// -// For example, -// PROFILE_START_ITERATION_ROUNDUP = 1000, and the current iteration is 2010 -// The profile will then be collected on the next multiple of 1000 ie. 3000 -// Note PROFILE_START_ITERATION_ROUNDUP will also take precedence over -// PROFILE_START_TIME. -constexpr char kProfileStartIterationRoundUpKey[] - = "PROFILE_START_ITERATION_ROUNDUP"; - -// Enable on-demand trigger via kill -USR2 -// When triggered in this way, /tmp/libkineto.conf will be used as config. -constexpr char kEnableSigUsr2Key[] = "ENABLE_SIGUSR2"; - -// Enable communication through IPC Fabric -// and disable thrift communication with dynolog daemon -constexpr char kEnableIpcFabricKey[] = "ENABLE_IPC_FABRIC"; - -// Verbose log level -// The actual glog is not used and --v and --vmodule has no effect. -// Instead set the verbose level and modules in the config file. -constexpr char kLogVerboseLevelKey[] = "VERBOSE_LOG_LEVEL"; -// By default, all modules will log verbose messages >= verboseLogLevel. -// But to reduce noise we can specify one or more modules of interest. -// A module is a C/C++ object file (source file name), -// Example argument: ActivityProfiler.cpp,output_json.cpp -constexpr char kLogVerboseModulesKey[] = "VERBOSE_LOG_MODULES"; - -// Max devices supported on any system -constexpr uint8_t kMaxDevices = 8; - -namespace { - -struct FactoryMap { - - void addFactory( - std::string name, - std::function factory) { - std::lock_guard lock(lock_); - factories_[name] = factory; - } - - void addFeatureConfigs(Config& cfg) { - std::lock_guard lock(lock_); - for (const auto& p : factories_) { - cfg.addFeature(p.first, p.second(cfg)); - } - } - -// Config factories are shared between objects and since -// config objects can be created by multiple threads, we need a lock. - std::mutex lock_; - std::map> factories_; -}; - -std::shared_ptr configFactories() { - // Ensure this is safe to call during shutdown, even as static - // destructors are invoked. Once factories destructor has been - // invoked, weak_ptr.lock() will return nullptr. - // But calls before that point will have a valid shared_ptr, - // delaying destruction of the underlying FactoryMap. - static auto factories = std::make_shared(); - static std::weak_ptr weak_ptr = factories; - return weak_ptr.lock(); -} - -} // namespace - -void Config::addConfigFactory( - std::string name, - std::function factory) { - auto factories = configFactories(); - if (factories) { - factories->addFactory(name, factory); - } -} - -static string defaultTraceFileName() { - return fmt::format("/tmp/libkineto_activities_{}.json", processId()); -} - -Config::Config() - : verboseLogLevel_(-1), - samplePeriod_(kDefaultSamplePeriodMsecs), - reportPeriod_(duration_cast(kDefaultReportPeriodSecs)), - samplesPerReport_(kDefaultSamplesPerReport), - eventProfilerOnDemandDuration_(seconds(0)), - eventProfilerMaxInstancesPerGpu_(kDefaultMaxEventProfilersPerGpu), - eventProfilerHeartbeatMonitorPeriod_( - kDefaultEventProfilerHearbeatMonitorPeriod), - multiplexPeriod_(kDefaultMultiplexPeriodMsecs), - activityProfilerEnabled_(true), - activitiesLogFile_(defaultTraceFileName()), - activitiesLogUrl_(fmt::format("file://{}", activitiesLogFile_)), - activitiesMaxGpuBufferSize_(kDefaultActivitiesMaxGpuBufferSize), - activitiesWarmupDuration_(kDefaultActivitiesWarmupDurationSecs), - activitiesWarmupIterations_(0), - activitiesDuration_(kDefaultActivitiesProfileDurationMSecs), - activitiesRunIterations_(0), - activitiesOnDemandTimestamp_(milliseconds(0)), - profileStartTime_(milliseconds(0)), - profileStartIteration_(-1), - profileStartIterationRoundUp_(-1), - requestTimestamp_(milliseconds(0)), - enableSigUsr2_(false), - enableIpcFabric_(false) { - auto factories = configFactories(); - if (factories) { - factories->addFeatureConfigs(*this); - } -} - -uint8_t Config::createDeviceMask(const string& val) { - uint8_t res = 0; - for (const auto& d : splitAndTrim(val, ',')) { - res |= 1 << toIntRange(d, 0, kMaxDevices - 1); - } - return res; -} - -const seconds Config::maxRequestAge() const { - return kMaxRequestAge; -} - -static std::string getTimeStr(time_point t) { - std::time_t t_c = system_clock::to_time_t(t); - return fmt::format("{:%H:%M:%S}", fmt::localtime(t_c)); -} - -static time_point handleRequestTimestamp(int64_t ms) { - auto t = time_point(milliseconds(ms)); - auto now = system_clock::now(); - if (t > now) { - throw std::invalid_argument(fmt::format( - "Invalid {}: {} - time is in future", - kRequestTimestampKey, - getTimeStr(t))); - } else if ((now - t) > kMaxRequestAge) { - throw std::invalid_argument(fmt::format( - "Invalid {}: {} - time is more than {}s in the past", - kRequestTimestampKey, - getTimeStr(t), - kMaxRequestAge.count())); - } - return t; -} - -void Config::setActivityTypes( - const std::vector& selected_activities) { - selectedActivityTypes_.clear(); - if (selected_activities.size() > 0) { - for (const auto& activity : selected_activities) { - if (activity == "") { - continue; - } - selectedActivityTypes_.insert(toActivityType(activity)); - } - } -} - -bool Config::handleOption(const std::string& name, std::string& val) { - // Event Profiler - if (!name.compare(kEventsKey)) { - vector event_names = splitAndTrim(val, ','); - eventNames_.insert(event_names.begin(), event_names.end()); - } else if (!name.compare(kMetricsKey)) { - vector metric_names = splitAndTrim(val, ','); - metricNames_.insert(metric_names.begin(), metric_names.end()); - } else if (!name.compare(kSamplePeriodKey)) { - samplePeriod_ = milliseconds(toInt32(val)); - } else if (!name.compare(kMultiplexPeriodKey)) { - multiplexPeriod_ = milliseconds(toInt32(val)); - } else if (!name.compare(kReportPeriodKey)) { - setReportPeriod(seconds(toInt32(val))); - } else if (!name.compare(kSamplesPerReportKey)) { - samplesPerReport_ = toInt32(val); - } else if (!name.compare(kEventsLogFileKey)) { - eventLogFile_ = val; - } else if (!name.compare(kEventsEnabledDevicesKey)) { - eventProfilerDeviceMask_ = createDeviceMask(val); - } else if (!name.compare(kOnDemandDurationKey)) { - eventProfilerOnDemandDuration_ = seconds(toInt32(val)); - eventProfilerOnDemandTimestamp_ = timestamp(); - } else if (!name.compare(kMaxEventProfilersPerGpuKey)) { - eventProfilerMaxInstancesPerGpu_ = toInt32(val); - } else if (!name.compare(kHeartbeatMonitorPeriodKey)) { - eventProfilerHeartbeatMonitorPeriod_ = seconds(toInt32(val)); - } - - // Activity Profiler - else if (!name.compare(kActivitiesDurationKey)) { - activitiesDuration_ = - duration_cast(seconds(toInt32(val))); - activitiesOnDemandTimestamp_ = timestamp(); - } else if (!name.compare(kActivityTypesKey)) { - vector activity_types = splitAndTrim(toLower(val), ','); - setActivityTypes(activity_types); - } else if (!name.compare(kActivitiesDurationMsecsKey)) { - activitiesDuration_ = milliseconds(toInt32(val)); - activitiesOnDemandTimestamp_ = timestamp(); - } else if (!name.compare(kActivitiesIterationsKey)) { - activitiesRunIterations_ = toInt32(val); - activitiesOnDemandTimestamp_ = timestamp(); - } else if (!name.compare(kLogVerboseLevelKey)) { - verboseLogLevel_ = toInt32(val); - } else if (!name.compare(kLogVerboseModulesKey)) { - verboseLogModules_ = splitAndTrim(val, ','); - } else if (!name.compare(kActivitiesEnabledKey)) { - activityProfilerEnabled_ = toBool(val); - } else if (!name.compare(kActivitiesLogFileKey)) { - activitiesLogFile_ = val; - activitiesLogUrl_ = fmt::format("file://{}", val); - activitiesOnDemandTimestamp_ = timestamp(); - } else if (!name.compare(kActivitiesMaxGpuBufferSizeKey)) { - activitiesMaxGpuBufferSize_ = toInt32(val) * 1024 * 1024; - } else if (!name.compare(kActivitiesWarmupDurationSecsKey)) { - activitiesWarmupDuration_ = seconds(toInt32(val)); - } else if (!name.compare(kActivitiesWarmupIterationsKey)) { - activitiesWarmupIterations_ = toInt32(val); - } - - // Client Interface - else if (!name.compare(kClientInterfaceEnableOpInputsCollection)) { - enableOpInputsCollection_ = toBool(val); - } - - // Common - else if (!name.compare(kRequestTimestampKey)) { - VLOG(0) << kRequestTimestampKey - << " has been deprecated - please use " - << kProfileStartTimeKey; - requestTimestamp_ = handleRequestTimestamp(toInt64(val)); - } else if (!name.compare(kProfileStartTimeKey)) { - profileStartTime_ = - time_point(milliseconds(toInt64(val))); - } else if (!name.compare(kProfileStartIterationKey)) { - profileStartIteration_ = toInt32(val); - } else if (!name.compare(kProfileStartIterationRoundUpKey)) { - profileStartIterationRoundUp_ = toInt32(val); - } else if (!name.compare(kEnableSigUsr2Key)) { - enableSigUsr2_ = toBool(val); - } else if (!name.compare(kEnableIpcFabricKey)) { - enableIpcFabric_ = toBool(val); - } else { - return false; - } - return true; -} - -std::chrono::milliseconds Config::activitiesDurationDefault() const { - return kDefaultActivitiesProfileDurationMSecs; -}; - -void Config::updateActivityProfilerRequestReceivedTime() { - activitiesOnDemandTimestamp_ = system_clock::now(); -} - -void Config::setClientDefaults() { - AbstractConfig::setClientDefaults(); - activitiesLogToMemory_ = true; -} - -void Config::validate( - const time_point& fallbackProfileStartTime) { - if (samplePeriod_.count() == 0) { - LOG(WARNING) << "Sample period must be greater than 0, setting to 1ms"; - samplePeriod_ = milliseconds(1); - } - - if (multiplexPeriod_ < samplePeriod_) { - LOG(WARNING) << "Multiplex period can not be smaller " - << "than sample period"; - LOG(WARNING) << "Setting multiplex period to " << samplePeriod_.count() - << "ms"; - multiplexPeriod_ = samplePeriod_; - } - - if ((multiplexPeriod_ % samplePeriod_).count() != 0) { - LOG(WARNING) << "Multiplex period must be a " - << "multiple of sample period"; - multiplexPeriod_ = alignUp(multiplexPeriod_, samplePeriod_); - LOG(WARNING) << "Setting multiplex period to " << multiplexPeriod_.count() - << "ms"; - } - - if ((reportPeriod_ % multiplexPeriod_).count() != 0 || - reportPeriod_.count() == 0) { - LOG(WARNING) << "Report period must be a " - << "multiple of multiplex period"; - reportPeriod_ = alignUp(reportPeriod_, multiplexPeriod_); - LOG(WARNING) << "Setting report period to " << reportPeriod_.count() - << "ms"; - } - - if (samplesPerReport_ < 1) { - LOG(WARNING) << "Samples per report must be in the range " - << "[1, report period / sample period]"; - LOG(WARNING) << "Setting samples per report to 1"; - samplesPerReport_ = 1; - } - - int max_samples_per_report = reportPeriod_ / samplePeriod_; - if (samplesPerReport_ > max_samples_per_report) { - LOG(WARNING) << "Samples per report must be in the range " - << "[1, report period / sample period] ([1, " - << reportPeriod_.count() << "ms / " << samplePeriod_.count() - << "ms = " << max_samples_per_report << "])"; - LOG(WARNING) << "Setting samples per report to " << max_samples_per_report; - samplesPerReport_ = max_samples_per_report; - } - - if (!hasProfileStartTime()) { - VLOG(0) - << "No explicit timestamp has been set. " - << "Defaulting it to now + activitiesWarmupDuration with buffer."; - profileStartTime_ = fallbackProfileStartTime + - activitiesWarmupDuration() + kDefaultBufferUntilWarmup; - } - - if (profileStartIterationRoundUp_ == 0) { - // setting to 0 will mess up modulo arithmetic, set it to -1 so it has no effect - LOG(WARNING) << "Profiler start iteration round up should be >= 1."; - profileStartIterationRoundUp_ = -1; - } - - if (profileStartIterationRoundUp_ > 0 && !hasProfileStartIteration()) { - VLOG(0) << "Setting profiler start iteration to 0 so this config is " - << "triggered via iteration count."; - profileStartIteration_ = 0; - } - - if (selectedActivityTypes_.size() == 0) { - selectDefaultActivityTypes(); - } -} - -void Config::setReportPeriod(milliseconds msecs) { - reportPeriod_ = msecs; -} - -void Config::printActivityProfilerConfig(std::ostream& s) const { - s << "Log file: " << activitiesLogFile() << std::endl; - if (hasProfileStartIteration()) { - s << "Trace start Iteration: " << profileStartIteration() << std::endl; - s << "Trace warmup Iterations: " << activitiesWarmupIterations() << std::endl; - s << "Trace profile Iterations: " << activitiesRunIterations() << std::endl; - if (profileStartIterationRoundUp() > 0) { - s << "Trace start iteration roundup : " << profileStartIterationRoundUp() - << std::endl; - } - } else if (hasProfileStartTime()) { - std::time_t t_c = system_clock::to_time_t(requestTimestamp()); - LOG(INFO) << "Trace start time: " - << fmt::format("{:%Y-%m-%d %H:%M:%S}", fmt::localtime(t_c)); - s << "Trace duration: " << activitiesDuration().count() << "ms" - << std::endl; - s << "Warmup duration: " << activitiesWarmupDuration().count() << "s" - << std::endl; - } - - s << "Max GPU buffer size: " << activitiesMaxGpuBufferSize() / 1024 / 1024 - << "MB" << std::endl; - - std::vector activities; - for (const auto& activity : selectedActivityTypes_) { - activities.push_back(toString(activity)); - } - s << "Enabled activities: " - << fmt::format("{}", fmt::join(activities, ",")) << std::endl; - - AbstractConfig::printActivityProfilerConfig(s); -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.cpp b/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.cpp deleted file mode 100644 index 4080b678d371e98757897d4d7726c159887377e1..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.cpp +++ /dev/null @@ -1,300 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "ConfigLoader.h" - -#ifdef __linux__ -#include -#endif - -#include -#include -#include -#include -#include - -#include "DaemonConfigLoader.h" - -#include "Logger.h" - -using namespace std::chrono; -using std::string; - -namespace KINETO_NAMESPACE { - -using namespace libkineto; - -constexpr char kConfigFileEnvVar[] = "KINETO_CONFIG"; -#ifdef __linux__ -constexpr char kConfigFile[] = "/etc/libkineto.conf"; -constexpr char kOnDemandConfigFile[] = "/tmp/libkineto.conf"; -#else -constexpr char kConfigFile[] = "libkineto.conf"; -constexpr char kOnDemandConfigFile[] = "libkineto.conf"; -#endif - -constexpr std::chrono::seconds kConfigUpdateIntervalSecs(300); -constexpr std::chrono::seconds kOnDemandConfigUpdateIntervalSecs(5); - -#ifdef __linux__ -static struct sigaction originalUsr2Handler = {}; -#endif - -// Use SIGUSR2 to initiate profiling. -// Look for an on-demand config file. -// If none is found, default to base config. -// Try to not affect existing handlers -static bool hasOriginalSignalHandler() { -#ifdef __linux__ - return originalUsr2Handler.sa_handler != nullptr || - originalUsr2Handler.sa_sigaction != nullptr; -#else - return false; -#endif -} - -static void handle_signal(int signal) { -#ifdef __linux__ - if (signal == SIGUSR2) { - ConfigLoader::instance().handleOnDemandSignal(); - if (hasOriginalSignalHandler()) { - // Invoke original handler and reinstate ours - struct sigaction act; - sigaction(SIGUSR2, &originalUsr2Handler, &act); - raise(SIGUSR2); - sigaction(SIGUSR2, &act, &originalUsr2Handler); - } - } -#endif -} - -static void setupSignalHandler(bool enableSigUsr2) { -#ifdef __linux__ - if (enableSigUsr2) { - struct sigaction act = {}; - act.sa_handler = &handle_signal; - act.sa_flags = SA_NODEFER; - if (sigaction(SIGUSR2, &act, &originalUsr2Handler) < 0) { - PLOG(ERROR) << "Failed to register SIGUSR2 handler"; - } - if (originalUsr2Handler.sa_handler == &handle_signal) { - originalUsr2Handler = {}; - } - } else if (hasOriginalSignalHandler()) { - sigaction(SIGUSR2, &originalUsr2Handler, nullptr); - originalUsr2Handler = {}; - } -#endif -} - -// return an empty string if reading gets any errors. Otherwise a config string. -static std::string readConfigFromConfigFile(const char* filename) { - // Read whole file into a string. - std::ifstream file(filename); - std::string conf; - try { - conf.assign( - std::istreambuf_iterator(file), std::istreambuf_iterator()); - } catch (std::exception& e) { - VLOG(0) << "Error reading " << filename << ": " - << e.what(); - conf = ""; - } - return conf; -} - -static std::function()>& -daemonConfigLoaderFactory() { - static std::function()> factory = nullptr; - return factory; -} - -void ConfigLoader::setDaemonConfigLoaderFactory( - std::function()> factory) { - daemonConfigLoaderFactory() = factory; -} - -ConfigLoader& ConfigLoader::instance() { - static ConfigLoader config_loader; - return config_loader; -} - -// return an empty string if polling gets any errors. Otherwise a config string. -std::string ConfigLoader::readOnDemandConfigFromDaemon( - time_point now) { - if (!daemonConfigLoader_) { - return ""; - } - bool events = canHandlerAcceptConfig(ConfigKind::EventProfiler); - bool activities = canHandlerAcceptConfig(ConfigKind::ActivityProfiler); - return daemonConfigLoader_->readOnDemandConfig(events, activities); -} - -int ConfigLoader::contextCountForGpu(uint32_t device) { - if (!daemonConfigLoader_) { - // FIXME: Throw error? - return 0; - } - return daemonConfigLoader_->gpuContextCount(device); -} - -ConfigLoader::ConfigLoader() - : configUpdateIntervalSecs_(kConfigUpdateIntervalSecs), - onDemandConfigUpdateIntervalSecs_(kOnDemandConfigUpdateIntervalSecs), - stopFlag_(false), - onDemandSignal_(false) { -} - -void ConfigLoader::startThread() { - if (!updateThread_) { - // Create default base config here - at this point static initializers - // of extensions should have run and registered all config feature factories - std::lock_guard lock(configLock_); - if (!config_) { - config_ = std::make_unique(); - } - updateThread_ = - std::make_unique(&ConfigLoader::updateConfigThread, this); - } -} - -ConfigLoader::~ConfigLoader() { - if (updateThread_) { - stopFlag_ = true; - { - std::lock_guard lock(updateThreadMutex_); - updateThreadCondVar_.notify_one(); - } - updateThread_->join(); - } -#if !USE_GOOGLE_LOG - Logger::clearLoggerObservers(); -#endif // !USE_GOOGLE_LOG -} - -void ConfigLoader::handleOnDemandSignal() { - onDemandSignal_ = true; - { - std::lock_guard lock(updateThreadMutex_); - updateThreadCondVar_.notify_one(); - } -} - -const char* ConfigLoader::configFileName() { - if (!configFileName_) { - configFileName_ = getenv(kConfigFileEnvVar); - if (configFileName_ == nullptr) { - configFileName_ = kConfigFile; - } - } - return configFileName_; -} - -DaemonConfigLoader* ConfigLoader::daemonConfigLoader() { - if (!daemonConfigLoader_ && daemonConfigLoaderFactory()) { - daemonConfigLoader_ = daemonConfigLoaderFactory()(); - daemonConfigLoader_->setCommunicationFabric(config_->ipcFabricEnabled()); - } - return daemonConfigLoader_.get(); -} - -void ConfigLoader::updateBaseConfig() { - // First try reading local config file - // If that fails, read from daemon - // TODO: Invert these once daemon path fully rolled out - std::string config_str = readConfigFromConfigFile(configFileName()); - if (config_str.empty() && daemonConfigLoader()) { - // If local config file was not successfully loaded (e.g. not found) - // then try the daemon - config_str = daemonConfigLoader()->readBaseConfig(); - } - if (config_str != config_->source()) { - std::lock_guard lock(configLock_); - config_ = std::make_unique(); - config_->parse(config_str); - if (daemonConfigLoader()) { - daemonConfigLoader()->setCommunicationFabric(config_->ipcFabricEnabled()); - } - setupSignalHandler(config_->sigUsr2Enabled()); - SET_LOG_VERBOSITY_LEVEL( - config_->verboseLogLevel(), - config_->verboseLogModules()); - VLOG(0) << "Detected base config change"; - } -} - -void ConfigLoader::configureFromSignal( - time_point now, - Config& config) { - LOG(INFO) << "Received on-demand profiling signal, " - << "reading config from " << kOnDemandConfigFile; - // Reset start time to 0 in order to compute new default start time - const std::string config_str = "PROFILE_START_TIME=0\n" - + readConfigFromConfigFile(kOnDemandConfigFile); - config.parse(config_str); - config.setSignalDefaults(); - notifyHandlers(config); -} - -void ConfigLoader::configureFromDaemon( - time_point now, - Config& config) { - const std::string config_str = readOnDemandConfigFromDaemon(now); - if (config_str.empty()) { - return; - } - - LOG(INFO) << "Received config from dyno:\n" << config_str; - config.parse(config_str); - notifyHandlers(config); -} - -void ConfigLoader::updateConfigThread() { - auto now = system_clock::now(); - auto next_config_load_time = now; - auto next_on_demand_load_time = now + onDemandConfigUpdateIntervalSecs_; - seconds interval = configUpdateIntervalSecs_; - if (interval > onDemandConfigUpdateIntervalSecs_) { - interval = onDemandConfigUpdateIntervalSecs_; - } - auto onDemandConfig = std::make_unique(); - - // This can potentially sleep for long periods of time, so allow - // the desctructor to wake it to avoid a 5-minute long destruct period. - for (;;) { - { - std::unique_lock lock(updateThreadMutex_); - updateThreadCondVar_.wait_for(lock, interval); - } - if (stopFlag_) { - break; - } - now = system_clock::now(); - if (now > next_config_load_time) { - updateBaseConfig(); - next_config_load_time = now + configUpdateIntervalSecs_; - } - if (onDemandSignal_.exchange(false)) { - onDemandConfig = config_->clone(); - configureFromSignal(now, *onDemandConfig); - } else if (now > next_on_demand_load_time) { - onDemandConfig = std::make_unique(); - configureFromDaemon(now, *onDemandConfig); - next_on_demand_load_time = now + onDemandConfigUpdateIntervalSecs_; - } - if (onDemandConfig->verboseLogLevel() >= 0) { - LOG(INFO) << "Setting verbose level to " - << onDemandConfig->verboseLogLevel() - << " from on-demand config"; - SET_LOG_VERBOSITY_LEVEL( - onDemandConfig->verboseLogLevel(), - onDemandConfig->verboseLogModules()); - } - } -} - -bool ConfigLoader::hasNewConfig(const Config& oldConfig) { - std::lock_guard lock(configLock_); - return config_->timestamp() > oldConfig.timestamp(); -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.h b/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.h deleted file mode 100644 index 4ce3468e48db116b2a40d992f000a3af1338e70a..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ConfigLoader.h +++ /dev/null @@ -1,147 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "Config.h" - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "ILoggerObserver.h" - -namespace libkineto { - class LibkinetoApi; -} - -namespace KINETO_NAMESPACE { - -using namespace libkineto; -class DaemonConfigLoader; - -class ConfigLoader { - public: - - static ConfigLoader& instance(); - - enum ConfigKind { - ActivityProfiler = 0, - EventProfiler, - NumConfigKinds - }; - - struct ConfigHandler { - virtual ~ConfigHandler() {} - virtual bool canAcceptConfig() = 0; - virtual void acceptConfig(const Config& cfg) = 0; - }; - - void addHandler(ConfigKind kind, ConfigHandler* handler) { - std::lock_guard lock(updateThreadMutex_); - handlers_[kind].push_back(handler); - startThread(); - } - - void removeHandler(ConfigKind kind, ConfigHandler* handler) { - std::lock_guard lock(updateThreadMutex_); - auto it = std::find( - handlers_[kind].begin(), handlers_[kind].end(), handler); - if (it != handlers_[kind].end()) { - handlers_[kind].erase(it); - } - } - - void notifyHandlers(const Config& cfg) { - std::lock_guard lock(updateThreadMutex_); - for (auto& key_val : handlers_) { - for (ConfigHandler* handler : key_val.second) { - handler->acceptConfig(cfg); - } - } - } - - bool canHandlerAcceptConfig(ConfigKind kind) { - std::lock_guard lock(updateThreadMutex_); - for (ConfigHandler* handler : handlers_[kind]) { - if (!handler->canAcceptConfig()) { - return false; - } - } - return true; - } - - void initBaseConfig() { - bool init = false; - { - std::lock_guard lock(configLock_); - init = !config_ || config_->source().empty(); - } - if (init) { - updateBaseConfig(); - } - } - - inline std::unique_ptr getConfigCopy() { - std::lock_guard lock(configLock_); - return config_->clone(); - } - - bool hasNewConfig(const Config& oldConfig); - int contextCountForGpu(uint32_t gpu); - - void handleOnDemandSignal(); - - static void setDaemonConfigLoaderFactory( - std::function()> factory); - - private: - ConfigLoader(); - ~ConfigLoader(); - - const char* configFileName(); - DaemonConfigLoader* daemonConfigLoader(); - - void startThread(); - void updateConfigThread(); - void updateBaseConfig(); - - // Create configuration when receiving SIGUSR2 - void configureFromSignal( - std::chrono::time_point now, - Config& config); - - // Create configuration when receiving request from a daemon - void configureFromDaemon( - std::chrono::time_point now, - Config& config); - - std::string readOnDemandConfigFromDaemon( - std::chrono::time_point now); - - std::mutex configLock_; - std::atomic configFileName_{nullptr}; - std::unique_ptr config_; - std::unique_ptr daemonConfigLoader_; - std::map> handlers_; - - std::chrono::seconds configUpdateIntervalSecs_; - std::chrono::seconds onDemandConfigUpdateIntervalSecs_; - std::unique_ptr updateThread_; - std::condition_variable updateThreadCondVar_; - std::mutex updateThreadMutex_; - std::atomic_bool stopFlag_{false}; - std::atomic_bool onDemandSignal_{false}; - -#if !USE_GOOGLE_LOG - std::unique_ptr> loggerObservers_; - std::mutex loggerObserversMutex_; -#endif // !USE_GOOGLE_LOG -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.cpp b/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.cpp deleted file mode 100644 index 1e909d5f9cfda13b95cc4abab547d964fe47b48a..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.cpp +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright (c) Kineto Contributors - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "CudaDeviceProperties.h" - -#include -#include - -#include -#include - -#include "Logger.h" - -namespace KINETO_NAMESPACE { - -static const std::vector createDeviceProps() { - std::vector props; - int device_count; - cudaError_t error_id = cudaGetDeviceCount(&device_count); - // Return empty vector if error. - if (error_id != cudaSuccess) { - LOG(ERROR) << "cudaGetDeviceCount failed with code " << error_id; - return {}; - } - VLOG(0) << "Device count is " << device_count; - for (size_t i = 0; i < device_count; ++i) { - cudaDeviceProp prop; - error_id = cudaGetDeviceProperties(&prop, i); - // Return empty vector if any device property fail to get. - if (error_id != cudaSuccess) { - LOG(ERROR) << "cudaGetDeviceProperties failed with " << error_id; - return {}; - } - props.push_back(prop); - LOGGER_OBSERVER_ADD_DEVICE(i); - } - return props; -} - -static const std::vector& deviceProps() { - static const std::vector props = createDeviceProps(); - return props; -} - -static const std::string createDevicePropertiesJson( - size_t id, const cudaDeviceProp& props) { - return fmt::format(R"JSON( - {{ - "id": {}, "name": "{}", "totalGlobalMem": {}, - "computeMajor": {}, "computeMinor": {}, - "maxThreadsPerBlock": {}, "maxThreadsPerMultiprocessor": {}, - "regsPerBlock": {}, "regsPerMultiprocessor": {}, "warpSize": {}, - "sharedMemPerBlock": {}, "sharedMemPerMultiprocessor": {}, - "numSms": {}, "sharedMemPerBlockOptin": {} - }})JSON", - id, props.name, props.totalGlobalMem, - props.major, props.minor, - props.maxThreadsPerBlock, props.maxThreadsPerMultiProcessor, - props.regsPerBlock, props.regsPerMultiprocessor, props.warpSize, - props.sharedMemPerBlock, props.sharedMemPerMultiprocessor, - props.multiProcessorCount, props.sharedMemPerBlockOptin); -} - -static const std::string createDevicePropertiesJson() { - std::vector jsonProps; - const auto& props = deviceProps(); - for (size_t i = 0; i < props.size(); i++) { - jsonProps.push_back(createDevicePropertiesJson(i, props[i])); - } - return fmt::format("{}", fmt::join(jsonProps, ",")); -} - -const std::string& devicePropertiesJson() { - static std::string devicePropsJson = createDevicePropertiesJson(); - return devicePropsJson; -} - -int smCount(uint32_t deviceId) { - const std::vector &props = deviceProps(); - return deviceId >= props.size() ? 0 : - props[deviceId].multiProcessorCount; -} - -float kernelOccupancy( - uint32_t deviceId, - uint16_t registersPerThread, - int32_t staticSharedMemory, - int32_t dynamicSharedMemory, - int32_t blockX, - int32_t blockY, - int32_t blockZ, - float blocksPerSm) { - // Calculate occupancy - float occupancy = -1.0; - const std::vector &props = deviceProps(); - if (deviceId < props.size()) { - cudaOccFuncAttributes occFuncAttr; - occFuncAttr.maxThreadsPerBlock = INT_MAX; - occFuncAttr.numRegs = registersPerThread; - occFuncAttr.sharedSizeBytes = staticSharedMemory; - occFuncAttr.partitionedGCConfig = PARTITIONED_GC_OFF; - occFuncAttr.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT; - occFuncAttr.maxDynamicSharedSizeBytes = 0; - const cudaOccDeviceState occDeviceState = {}; - int blockSize = blockX * blockY * blockZ; - size_t dynamicSmemSize = dynamicSharedMemory; - cudaOccResult occ_result; - cudaOccDeviceProp prop(props[deviceId]); - cudaOccError status = cudaOccMaxActiveBlocksPerMultiprocessor( - &occ_result, &prop, &occFuncAttr, &occDeviceState, - blockSize, dynamicSmemSize); - if (status == CUDA_OCC_SUCCESS) { - if (occ_result.activeBlocksPerMultiprocessor < blocksPerSm) { - blocksPerSm = occ_result.activeBlocksPerMultiprocessor; - } - occupancy = blocksPerSm * blockSize / - (float) props[deviceId].maxThreadsPerMultiProcessor; - } else { - LOG_EVERY_N(ERROR, 1000) << "Failed to calculate occupancy, status = " - << status; - } - } - return occupancy; -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.h b/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.h deleted file mode 100644 index b731fde0c2aab4c9bd3e97f475d204dad02986e7..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.h +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) Kineto Contributors - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -namespace KINETO_NAMESPACE { - -int smCount(uint32_t deviceId); - -// Return estimated achieved occupancy for a kernel -float kernelOccupancy( - uint32_t deviceId, - uint16_t registersPerThread, - int32_t staticSharedMemory, - int32_t dynamicSharedMemory, - int32_t blockX, - int32_t blockY, - int32_t blockZ, - float blocks_per_sm); - -// Return compute properties for each device as a json string -const std::string& devicePropertiesJson(); - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.h b/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.h deleted file mode 100644 index 09c29504060ecbbac609aa2d021ff643f45c143e..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.h +++ /dev/null @@ -1,114 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include - -#include "ITraceActivity.h" -#include "CuptiActivityPlatform.h" -#include "ThreadUtil.h" -#include "cupti_strings.h" - -namespace libkineto { - class ActivityLogger; -} - -namespace KINETO_NAMESPACE { - -using namespace libkineto; -struct TraceSpan; - -// These classes wrap the various CUPTI activity types -// into subclasses of ITraceActivity so that they can all be accessed -// using the ITraceActivity interface and logged via ActivityLogger. - -// Abstract base class, templated on Cupti activity type -template -struct CuptiActivity : public ITraceActivity { - explicit CuptiActivity(const T* activity, const ITraceActivity* linked) - : activity_(*activity), linked_(linked) {} - int64_t timestamp() const override { - return nsToUs(unixEpochTimestamp(activity_.start)); - } - int64_t duration() const override { - return nsToUs(activity_.end - activity_.start); - } - // TODO(T107507796): Deprecate ITraceActivity - int64_t correlationId() const override {return 0;} - int32_t getThreadId() const override {return 0;} - const ITraceActivity* linkedActivity() const override {return linked_;} - int flowType() const override {return kLinkAsyncCpuGpu;} - int flowId() const override {return correlationId();} - const T& raw() const {return activity_;} - const TraceSpan* traceSpan() const override {return nullptr;} - - protected: - const T& activity_; - const ITraceActivity* linked_{nullptr}; -}; - -// CUpti_ActivityAPI - CUDA runtime activities -struct RuntimeActivity : public CuptiActivity { - explicit RuntimeActivity( - const CUpti_ActivityAPI* activity, - const ITraceActivity* linked, - int32_t threadId) - : CuptiActivity(activity, linked), threadId_(threadId) {} - int64_t correlationId() const override {return activity_.correlationId;} - int64_t deviceId() const override {return processId();} - int64_t resourceId() const override {return threadId_;} - ActivityType type() const override {return ActivityType::CUDA_RUNTIME;} - bool flowStart() const override; - const std::string name() const override {return runtimeCbidName(activity_.cbid);} - void log(ActivityLogger& logger) const override; - const std::string metadataJson() const override; - - private: - const int32_t threadId_; -}; - -// CUpti_ActivityAPI - CUDA runtime activities -struct OverheadActivity : public CuptiActivity { - explicit OverheadActivity( - const CUpti_ActivityOverhead* activity, - const ITraceActivity* linked, - int32_t threadId=0) - : CuptiActivity(activity, linked), threadId_(threadId) {} - - int64_t timestamp() const override { - return nsToUs(unixEpochTimestamp(activity_.start)); - } - int64_t duration() const override { - return nsToUs(activity_.end - activity_.start); - } - // TODO: Update this with PID ordering - int64_t deviceId() const override {return -1;} - int64_t resourceId() const override {return threadId_;} - ActivityType type() const override {return ActivityType::OVERHEAD;} - bool flowStart() const override; - const std::string name() const override {return overheadKindString(activity_.overheadKind);} - void log(ActivityLogger& logger) const override; - const std::string metadataJson() const override; - - private: - const int32_t threadId_; -}; - -// Base class for GPU activities. -// Can also be instantiated directly. -template -struct GpuActivity : public CuptiActivity { - explicit GpuActivity(const T* activity, const ITraceActivity* linked) - : CuptiActivity(activity, linked) {} - int64_t correlationId() const override {return raw().correlationId;} - int64_t deviceId() const override {return raw().deviceId;} - int64_t resourceId() const override {return raw().streamId;} - ActivityType type() const override; - bool flowStart() const override {return false;} - const std::string name() const override; - void log(ActivityLogger& logger) const override; - const std::string metadataJson() const override; - const T& raw() const {return CuptiActivity::raw();} -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.tpp b/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.tpp deleted file mode 100644 index 1ff2dafe06b0016ce7b904ef4b55e047c69bcc1c..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiActivity.tpp +++ /dev/null @@ -1,111 +0,0 @@ - /* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "CuptiActivity.h" - -#include - -#include "Demangle.h" -#include "output_base.h" - -namespace KINETO_NAMESPACE { - -using namespace libkineto; - -template<> -inline const std::string GpuActivity::name() const { - return demangle(raw().name); -} - -template<> -inline ActivityType GpuActivity::type() const { - return ActivityType::CONCURRENT_KERNEL; -} - -static inline std::string memcpyName(uint8_t kind, uint8_t src, uint8_t dst) { - return fmt::format( - "Memcpy {} ({} -> {})", - memcpyKindString((CUpti_ActivityMemcpyKind)kind), - memoryKindString((CUpti_ActivityMemoryKind)src), - memoryKindString((CUpti_ActivityMemoryKind)dst)); -} - -template<> -inline ActivityType GpuActivity::type() const { - return ActivityType::GPU_MEMCPY; -} - -template<> -inline const std::string GpuActivity::name() const { - return memcpyName(raw().copyKind, raw().srcKind, raw().dstKind); -} - -template<> -inline ActivityType GpuActivity::type() const { - return ActivityType::GPU_MEMCPY; -} - -template<> -inline const std::string GpuActivity::name() const { - return memcpyName(raw().copyKind, raw().srcKind, raw().dstKind); -} - -template<> -inline const std::string GpuActivity::name() const { - const char* memory_kind = - memoryKindString((CUpti_ActivityMemoryKind)raw().memoryKind); - return fmt::format("Memset ({})", memory_kind); -} - -template<> -inline ActivityType GpuActivity::type() const { - return ActivityType::GPU_MEMSET; -} - -inline void RuntimeActivity::log(ActivityLogger& logger) const { - logger.handleActivity(*this); -} - -inline void OverheadActivity::log(ActivityLogger& logger) const { - logger.handleActivity(*this); -} - -inline bool OverheadActivity::flowStart() const { - return false; -} - -inline const std::string OverheadActivity::metadataJson() const { - return ""; -} - -template -inline void GpuActivity::log(ActivityLogger& logger) const { - logger.handleGpuActivity(*this); -} - -inline bool RuntimeActivity::flowStart() const { - return activity_.cbid == CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000 || - (activity_.cbid >= CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020 && - activity_.cbid <= CUPTI_RUNTIME_TRACE_CBID_cudaMemset2DAsync_v3020) || - activity_.cbid == - CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernel_v9000 || - activity_.cbid == - CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000; -} - -inline const std::string RuntimeActivity::metadataJson() const { - return fmt::format(R"JSON( - "cbid": {}, "correlation": {})JSON", - activity_.cbid, activity_.correlationId); -} - -template -inline const std::string GpuActivity::metadataJson() const { - return ""; -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.cpp deleted file mode 100644 index 5718bed2f89b06cc702d1b82976cd42e5fceebd0..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.cpp +++ /dev/null @@ -1,343 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "CuptiActivityApi.h" - -#include -#include - -#include "cupti_call.h" -#include "Logger.h" - -using namespace std::chrono; - -namespace KINETO_NAMESPACE { - -// TODO: do we want this to be configurable? -// Set to 2MB to avoid constantly creating buffers (espeically for networks -// that has many small memcpy such as sparseNN) -// Consider putting this on huge pages? -constexpr size_t kBufSize(2 * 1024 * 1024); - -CuptiActivityApi& CuptiActivityApi::singleton() { - static CuptiActivityApi instance; - return instance; -} - -void CuptiActivityApi::pushCorrelationID(int id, CorrelationFlowType type) { -#ifdef HAS_CUPTI - if (!singleton().externalCorrelationEnabled_) { - return; - } - VLOG(2) << "pushCorrelationID(" << id << ")"; - switch(type) { - case Default: - CUPTI_CALL(cuptiActivityPushExternalCorrelationId( - CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM0, id)); - break; - case User: - CUPTI_CALL(cuptiActivityPushExternalCorrelationId( - CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1, id)); - } -#endif -} - -void CuptiActivityApi::popCorrelationID(CorrelationFlowType type) { -#ifdef HAS_CUPTI - if (!singleton().externalCorrelationEnabled_) { - return; - } - switch(type) { - case Default: - CUPTI_CALL(cuptiActivityPopExternalCorrelationId( - CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM0, nullptr)); - break; - case User: - CUPTI_CALL(cuptiActivityPopExternalCorrelationId( - CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1, nullptr)); - } -#endif -} - -static int getSMCount() { -#ifdef HAS_CUPTI - // There may be a simpler way to get the number of SMs.... - // Look for domain_d - this has 80 instances on Volta and - // 56 instances on Pascal, corresponding to the number of SMs - // FIXME: This does not work on Turing and later - uint32_t domainCount{0}; - CUPTI_CALL(cuptiDeviceGetNumEventDomains(0, &domainCount)); - std::vector ids(domainCount); - size_t sz = sizeof(CUpti_EventDomainID) * domainCount; - CUPTI_CALL(cuptiDeviceEnumEventDomains(0, &sz, ids.data())); - for (CUpti_EventDomainID id : ids) { - char name[16]; - name[0] = '\0'; - sz = sizeof(name); - CUPTI_CALL(cuptiEventDomainGetAttribute( - id, CUPTI_EVENT_DOMAIN_ATTR_NAME, &sz, name)); - if (strncmp(name, "domain_d", sz) == 0) { - uint32_t count{0}; - sz = sizeof(count); - CUPTI_CALL(cuptiDeviceGetEventDomainAttribute( - 0, id, CUPTI_EVENT_DOMAIN_ATTR_TOTAL_INSTANCE_COUNT, &sz, &count)); - return count; - } - } -#endif - - return -1; -} - -int CuptiActivityApi::smCount() { - static int sm_count = getSMCount(); - return sm_count; -} - -static bool nextActivityRecord( - uint8_t* buffer, - size_t valid_size, - CUpti_Activity*& record) { -#ifdef HAS_CUPTI - CUptiResult status = CUPTI_CALL_NOWARN( - cuptiActivityGetNextRecord(buffer, valid_size, &record)); - if (status != CUPTI_SUCCESS) { - if (status != CUPTI_ERROR_MAX_LIMIT_REACHED) { - CUPTI_CALL(status); - } - record = nullptr; - } -#endif - return record != nullptr; -} - -void CuptiActivityApi::setMaxBufferSize(int size) { - maxGpuBufferCount_ = 1 + size / kBufSize; -} - -void CuptiActivityApi::forceLoadCupti() { -#ifdef HAS_CUPTI - CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL)); -#endif -} - -#ifdef HAS_CUPTI -void CUPTIAPI CuptiActivityApi::bufferRequestedTrampoline( - uint8_t** buffer, - size_t* size, - size_t* maxNumRecords) { - singleton().bufferRequested(buffer, size, maxNumRecords); -} - -void CuptiActivityApi::bufferRequested( - uint8_t** buffer, size_t* size, size_t* maxNumRecords) { - std::lock_guard guard(mutex_); - if (allocatedGpuTraceBuffers_.size() >= maxGpuBufferCount_) { - stopCollection = true; - LOG(WARNING) << "Exceeded max GPU buffer count (" - << allocatedGpuTraceBuffers_.size() - << " > " << maxGpuBufferCount_ - << ") - terminating tracing"; - } - - auto buf = std::make_unique(kBufSize); - *buffer = buf->data(); - *size = kBufSize; - - allocatedGpuTraceBuffers_[*buffer] = std::move(buf); - - *maxNumRecords = 0; -} -#endif - -std::unique_ptr -CuptiActivityApi::activityBuffers() { - { - std::lock_guard guard(mutex_); - if (allocatedGpuTraceBuffers_.empty()) { - return nullptr; - } - } - -#ifdef HAS_CUPTI - VLOG(1) << "Flushing GPU activity buffers"; - time_point t1; - if (VLOG_IS_ON(1)) { - t1 = system_clock::now(); - } - // Can't hold mutex_ during this call, since bufferCompleted - // will be called by libcupti and mutex_ is acquired there. - CUPTI_CALL(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED)); - if (VLOG_IS_ON(1)) { - flushOverhead = - duration_cast(system_clock::now() - t1).count(); - } -#endif - std::lock_guard guard(mutex_); - // Transfer ownership of buffers to caller. A new map is created on-demand. - return std::move(readyGpuTraceBuffers_); -} - -#ifdef HAS_CUPTI -int CuptiActivityApi::processActivitiesForBuffer( - uint8_t* buf, - size_t validSize, - std::function handler) { - int count = 0; - if (buf && validSize) { - CUpti_Activity* record{nullptr}; - while ((nextActivityRecord(buf, validSize, record))) { - handler(record); - ++count; - } - } - return count; -} -#endif - -const std::pair CuptiActivityApi::processActivities( - CuptiActivityBufferMap& buffers, - std::function handler) { - std::pair res{0, 0}; -#ifdef HAS_CUPTI - for (auto& pair : buffers) { - // No lock needed - only accessed from this thread - auto& buf = pair.second; - res.first += processActivitiesForBuffer(buf->data(), buf->size(), handler); - res.second += buf->size(); - } -#endif - return res; -} - -void CuptiActivityApi::clearActivities() { - { - std::lock_guard guard(mutex_); - if (allocatedGpuTraceBuffers_.empty()) { - return; - } - } - // Can't hold mutex_ during this call, since bufferCompleted - // will be called by libcupti and mutex_ is acquired there. -#ifdef HAS_CUPTI - CUPTI_CALL(cuptiActivityFlushAll(0)); -#endif - // FIXME: We might want to make sure we reuse - // the same memory during warmup and tracing. - // Also, try to use the amount of memory required - // for active tracing during warmup. - std::lock_guard guard(mutex_); - // Throw away ready buffers as a result of above flush - readyGpuTraceBuffers_ = nullptr; -} - -#ifdef HAS_CUPTI -void CUPTIAPI CuptiActivityApi::bufferCompletedTrampoline( - CUcontext ctx, - uint32_t streamId, - uint8_t* buffer, - size_t /* unused */, - size_t validSize) { - singleton().bufferCompleted(ctx, streamId, buffer, 0, validSize); -} - -void CuptiActivityApi::bufferCompleted( - CUcontext ctx, - uint32_t streamId, - uint8_t* buffer, - size_t /* unused */, - size_t validSize) { - - std::lock_guard guard(mutex_); - auto it = allocatedGpuTraceBuffers_.find(buffer); - if (it == allocatedGpuTraceBuffers_.end()) { - LOG(ERROR) << "bufferCompleted called with unknown buffer: " - << (void*) buffer; - return; - } - - if (!readyGpuTraceBuffers_) { - readyGpuTraceBuffers_ = std::make_unique(); - } - // Set valid size of buffer before moving to ready map - it->second->setSize(validSize); - (*readyGpuTraceBuffers_)[it->first] = std::move(it->second); - allocatedGpuTraceBuffers_.erase(it); - - // report any records dropped from the queue; to avoid unnecessary cupti - // API calls, we make it report only in verbose mode (it doesn't happen - // often in our testing anyways) - if (VLOG_IS_ON(1)) { - size_t dropped = 0; - CUPTI_CALL(cuptiActivityGetNumDroppedRecords(ctx, streamId, &dropped)); - if (dropped != 0) { - LOG(WARNING) << "Dropped " << dropped << " activity records"; - } - } -} -#endif - -void CuptiActivityApi::enableCuptiActivities( - const std::set& selected_activities) { -#ifdef HAS_CUPTI - static bool registered = false; - if (!registered) { - CUPTI_CALL( - cuptiActivityRegisterCallbacks(bufferRequestedTrampoline, bufferCompletedTrampoline)); - } - - externalCorrelationEnabled_ = false; - for (const auto& activity : selected_activities) { - if (activity == ActivityType::GPU_MEMCPY) { - CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMCPY)); - } - if (activity == ActivityType::GPU_MEMSET) { - CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_MEMSET)); - } - if (activity == ActivityType::CONCURRENT_KERNEL) { - CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL)); - } - if (activity == ActivityType::EXTERNAL_CORRELATION) { - CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION)); - externalCorrelationEnabled_ = true; - } - if (activity == ActivityType::CUDA_RUNTIME) { - CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_RUNTIME)); - } - if (activity == ActivityType::OVERHEAD) { - CUPTI_CALL(cuptiActivityEnable(CUPTI_ACTIVITY_KIND_OVERHEAD)); - } - } -#endif - - // Explicitly enabled, so reset this flag if set - stopCollection = false; -} - -void CuptiActivityApi::disableCuptiActivities( - const std::set& selected_activities) { -#ifdef HAS_CUPTI - for (const auto& activity : selected_activities) { - if (activity == ActivityType::GPU_MEMCPY) { - CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMCPY)); - } - if (activity == ActivityType::GPU_MEMSET) { - CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_MEMSET)); - } - if (activity == ActivityType::CONCURRENT_KERNEL) { - CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL)); - } - if (activity == ActivityType::EXTERNAL_CORRELATION) { - CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION)); - } - if (activity == ActivityType::CUDA_RUNTIME) { - CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_RUNTIME)); - } - if (activity == ActivityType::OVERHEAD) { - CUPTI_CALL(cuptiActivityDisable(CUPTI_ACTIVITY_KIND_OVERHEAD)); - } - } - externalCorrelationEnabled_ = false; -#endif -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.h b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.h deleted file mode 100644 index 92af51ecac9ec99181c4726c3849894de9e32b33..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.h +++ /dev/null @@ -1,100 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include -#include - -#ifdef HAS_CUPTI -#include -#endif - -#include "ActivityType.h" -#include "CuptiActivityBuffer.h" - - -namespace KINETO_NAMESPACE { - -using namespace libkineto; - -#ifndef HAS_CUPTI -using CUpti_Activity = void; -#endif - -class CuptiActivityApi { - public: - enum CorrelationFlowType { - Default, - User - }; - - CuptiActivityApi() = default; - CuptiActivityApi(const CuptiActivityApi&) = delete; - CuptiActivityApi& operator=(const CuptiActivityApi&) = delete; - - virtual ~CuptiActivityApi() {} - - static CuptiActivityApi& singleton(); - - virtual int smCount(); - static void pushCorrelationID(int id, CorrelationFlowType type); - static void popCorrelationID(CorrelationFlowType type); - - void enableCuptiActivities( - const std::set& selected_activities); - void disableCuptiActivities( - const std::set& selected_activities); - void clearActivities(); - - virtual std::unique_ptr activityBuffers(); - - virtual const std::pair processActivities( - CuptiActivityBufferMap&, - std::function handler); - - void setMaxBufferSize(int size); - - std::atomic_bool stopCollection{false}; - int64_t flushOverhead{0}; - - static void forceLoadCupti(); - - private: -#ifdef HAS_CUPTI - int processActivitiesForBuffer( - uint8_t* buf, - size_t validSize, - std::function handler); - static void CUPTIAPI - bufferRequestedTrampoline(uint8_t** buffer, size_t* size, size_t* maxNumRecords); - static void CUPTIAPI bufferCompletedTrampoline( - CUcontext ctx, - uint32_t streamId, - uint8_t* buffer, - size_t /* unused */, - size_t validSize); -#endif // HAS_CUPTI - - int maxGpuBufferCount_{0}; - CuptiActivityBufferMap allocatedGpuTraceBuffers_; - std::unique_ptr readyGpuTraceBuffers_; - std::mutex mutex_; - bool externalCorrelationEnabled_{false}; - - protected: -#ifdef HAS_CUPTI - void bufferRequested(uint8_t** buffer, size_t* size, size_t* maxNumRecords); - void bufferCompleted( - CUcontext ctx, - uint32_t streamId, - uint8_t* buffer, - size_t /* unused */, - size_t validSize); -#endif -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityBuffer.h b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityBuffer.h deleted file mode 100644 index 1c3fbef62c8d8f42ff5da1718e20315cc1ba95d5..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityBuffer.h +++ /dev/null @@ -1,51 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "ITraceActivity.h" - -namespace KINETO_NAMESPACE { - -class CuptiActivityBuffer { - public: - explicit CuptiActivityBuffer(size_t size) : size_(size) { - buf_.reserve(size); - } - CuptiActivityBuffer() = delete; - CuptiActivityBuffer& operator=(const CuptiActivityBuffer&) = delete; - CuptiActivityBuffer(CuptiActivityBuffer&&) = default; - CuptiActivityBuffer& operator=(CuptiActivityBuffer&&) = default; - - size_t size() const { - return size_; - } - - void setSize(size_t size) { - assert(size <= buf_.capacity()); - size_ = size; - } - - uint8_t* data() { - return buf_.data(); - } - - private: - - std::vector buf_; - size_t size_; - - std::vector> wrappers_; -}; - -using CuptiActivityBufferMap = - std::map>; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.cpp deleted file mode 100644 index fa2ef2f3a8c9cbb7f10567c158d6ee3e8e26eed0..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include - -namespace chrono = std::chrono; - -namespace KINETO_NAMESPACE { - -#ifdef _WIN32 -uint64_t epochs_diff() { - // On Windows, steady_clock wraps the QueryPerformanceCounter function. - // https://docs.microsoft.com/en-us/cpp/standard-library/steady-clock-struct?view=msvc-160 - auto steady = - chrono::time_point_cast(chrono::steady_clock::now()); - auto system = - chrono::time_point_cast(chrono::system_clock::now()); - - auto time_since_unix = system.time_since_epoch().count(); - auto time_since_boot = steady.time_since_epoch().count(); - return time_since_unix - time_since_boot; -} - -uint64_t unixEpochTimestamp(uint64_t ts) { - static uint64_t diff = epochs_diff(); - return ts + diff; -} -#else -uint64_t unixEpochTimestamp(uint64_t ts) { - return ts; -} -#endif // _WIN32 - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.h b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.h deleted file mode 100644 index 78de8373d5fe391d48edffc897aff6893aa6f54f..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -namespace KINETO_NAMESPACE { - -// cupti's timestamps are platform specific. This function convert the raw -// cupti timestamp to time since unix epoch. So that on different platform, -// correction can work correctly. -uint64_t unixEpochTimestamp(uint64_t ts); - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.cpp deleted file mode 100644 index 97c23ef047d75aff75b56773a20801ce83fb1653..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.cpp +++ /dev/null @@ -1,841 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "CuptiActivityProfiler.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef HAS_CUPTI -#include -#endif - -#include "Config.h" -#include "time_since_epoch.h" -#ifdef HAS_CUPTI -#include "CuptiActivity.h" -#include "CuptiActivity.tpp" -#include "CuptiActivityApi.h" -#endif // HAS_CUPTI -#ifdef HAS_ROCTRACER -#include "RoctracerActivityApi.h" -#endif -#include "output_base.h" - -#include "Logger.h" -#include "ThreadUtil.h" - -using namespace std::chrono; -using namespace libkineto; -using std::string; - -namespace KINETO_NAMESPACE { - -void CuptiActivityProfiler::transferCpuTrace( - std::unique_ptr cpuTrace) { - std::lock_guard guard(mutex_); - const string& trace_name = cpuTrace->span.name; - if (currentRunloopState_ != RunloopState::CollectTrace && - currentRunloopState_ != RunloopState::ProcessTrace) { - VLOG(0) << "Trace collection not in progress - discarding span " - << trace_name; - return; - } - - cpuTrace->span.iteration = iterationCountMap_[trace_name]++; - - VLOG(0) << "Received iteration " << cpuTrace->span.iteration << " of span " - << trace_name << " (" << cpuTrace->activities.size() << " activities / " - << cpuTrace->gpuOpCount << " gpu activities)"; - traceBuffers_->cpu.push_back(std::move(cpuTrace)); -} - -#ifdef HAS_ROCTRACER -CuptiActivityProfiler::CuptiActivityProfiler(RoctracerActivityApi& cupti, bool cpuOnly) -#else -CuptiActivityProfiler::CuptiActivityProfiler(CuptiActivityApi& cupti, bool cpuOnly) -#endif - : cupti_(cupti), - flushOverhead_{0, 0}, - setupOverhead_{0, 0}, - cpuOnly_{cpuOnly}, - currentRunloopState_{RunloopState::WaitForRequest}, - stopCollection_{false} {} - -void CuptiActivityProfiler::processTraceInternal(ActivityLogger& logger) { - LOG(INFO) << "Processing " << traceBuffers_->cpu.size() - << " CPU buffers"; - VLOG(0) << "Profile time range: " << captureWindowStartTime_ << " - " - << captureWindowEndTime_; - logger.handleTraceStart(metadata_); - for (auto& cpu_trace : traceBuffers_->cpu) { - string trace_name = cpu_trace->span.name; - VLOG(0) << "Processing CPU buffer for " << trace_name << " (" - << cpu_trace->span.iteration << ") - " - << cpu_trace->activities.size() << " records"; - VLOG(0) << "Span time range: " << cpu_trace->span.startTime << " - " - << cpu_trace->span.endTime; - processCpuTrace(*cpu_trace, logger); - LOGGER_OBSERVER_ADD_EVENT_COUNT(cpu_trace->activities.size()); - } - -#ifdef HAS_CUPTI - if (!cpuOnly_) { - VLOG(0) << "Retrieving GPU activity buffers"; - traceBuffers_->gpu = cupti_.activityBuffers(); - if (VLOG_IS_ON(1)) { - addOverheadSample(flushOverhead_, cupti_.flushOverhead); - } - if (traceBuffers_->gpu) { - const auto count_and_size = cupti_.processActivities( - *traceBuffers_->gpu, - std::bind(&CuptiActivityProfiler::handleCuptiActivity, this, std::placeholders::_1, &logger)); - LOG(INFO) << "Processed " << count_and_size.first - << " GPU records (" << count_and_size.second << " bytes)"; - LOGGER_OBSERVER_ADD_EVENT_COUNT(count_and_size.first); - } - } -#endif // HAS_CUPTI -#ifdef HAS_ROCTRACER - if (!cpuOnly_) { - VLOG(0) << "Retrieving GPU activity buffers"; - const int count = cupti_.processActivities(logger); - LOG(INFO) << "Processed " << count - << " GPU records"; - LOGGER_OBSERVER_ADD_EVENT_COUNT(count); - } -#endif // HAS_ROCTRACER - - for (const auto& session : sessions_){ - LOG(INFO) << "Processing child profiler trace"; - session->processTrace(logger); - } - - finalizeTrace(*config_, logger); -} - -CuptiActivityProfiler::CpuGpuSpanPair& CuptiActivityProfiler::recordTraceSpan( - TraceSpan& span, int gpuOpCount) { - TraceSpan gpu_span(gpuOpCount, span.iteration, span.name, "GPU: "); - auto& iterations = traceSpans_[span.name]; - iterations.push_back({span, gpu_span}); - return iterations.back(); -} - -void CuptiActivityProfiler::processCpuTrace( - libkineto::CpuTraceBuffer& cpuTrace, - ActivityLogger& logger) { - if (cpuTrace.activities.size() == 0) { - LOG(WARNING) << "CPU trace is empty!"; - return; - } - - CpuGpuSpanPair& span_pair = recordTraceSpan(cpuTrace.span, cpuTrace.gpuOpCount); - TraceSpan& cpu_span = span_pair.first; - for (auto const& act : cpuTrace.activities) { - VLOG(2) << act.correlationId() << ": OP " << act.activityName; - if (config_->selectedActivityTypes().count(act.type())) { - act.log(logger); - } - clientActivityTraceMap_[act.correlationId()] = &span_pair; - activityMap_[act.correlationId()] = &act; - - recordThreadInfo(act.resourceId(), act.getThreadId(), act.deviceId()); - } - logger.handleTraceSpan(cpu_span); -} - -#ifdef HAS_CUPTI -inline void CuptiActivityProfiler::handleCorrelationActivity( - const CUpti_ActivityExternalCorrelation* correlation) { - if (correlation->externalKind == CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM0) { - cpuCorrelationMap_[correlation->correlationId] = correlation->externalId; - } else if (correlation->externalKind == CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1){ - userCorrelationMap_[correlation->correlationId] = correlation->externalId; - } else { - LOG(ERROR) << "Invalid CUpti_ActivityExternalCorrelation sent to handleCuptiActivity"; - } -} -#endif // HAS_CUPTI - -static GenericTraceActivity createUserGpuSpan( - const libkineto::ITraceActivity& cpuTraceActivity, - const libkineto::ITraceActivity& gpuTraceActivity) { - GenericTraceActivity res( - *cpuTraceActivity.traceSpan(), - ActivityType::GPU_USER_ANNOTATION, - cpuTraceActivity.name()); - res.startTime = gpuTraceActivity.timestamp(); - res.device = gpuTraceActivity.deviceId(); - res.resource = gpuTraceActivity.resourceId(); - res.endTime = - gpuTraceActivity.timestamp() + gpuTraceActivity.duration(); - res.id = cpuTraceActivity.correlationId(); - return res; -} - -void CuptiActivityProfiler::GpuUserEventMap::insertOrExtendEvent( - const ITraceActivity& userActivity, - const ITraceActivity& gpuActivity) { - StreamKey key(gpuActivity.deviceId(), gpuActivity.resourceId()); - CorrelationSpanMap& correlationSpanMap = streamSpanMap_[key]; - auto it = correlationSpanMap.find(userActivity.correlationId()); - if (it == correlationSpanMap.end()) { - auto it_success = correlationSpanMap.insert({ - userActivity.correlationId(), createUserGpuSpan(userActivity, gpuActivity) - }); - it = it_success.first; - } - GenericTraceActivity& span = it->second; - if (gpuActivity.timestamp() < span.startTime || span.startTime == 0) { - span.startTime = gpuActivity.timestamp(); - } - int64_t gpu_activity_end = gpuActivity.timestamp() + gpuActivity.duration(); - if (gpu_activity_end > span.endTime) { - span.endTime = gpu_activity_end; - } -} - -const CuptiActivityProfiler::CpuGpuSpanPair& CuptiActivityProfiler::defaultTraceSpan() { - static TraceSpan span(0, 0, "Unknown", ""); - static CpuGpuSpanPair span_pair(span, span); - return span_pair; -} - -void CuptiActivityProfiler::GpuUserEventMap::logEvents(ActivityLogger *logger) { - for (auto const& streamMapPair : streamSpanMap_) { - for (auto const& correlationSpanPair : streamMapPair.second) { - correlationSpanPair.second.log(*logger); - } - } -} - -#ifdef HAS_CUPTI -inline bool CuptiActivityProfiler::outOfRange(const ITraceActivity& act) { - bool out_of_range = act.timestamp() < captureWindowStartTime_ || - (act.timestamp() + act.duration()) > captureWindowEndTime_; - if (out_of_range) { - VLOG(2) << "TraceActivity outside of profiling window: " << act.name() - << " (" << act.timestamp() << " < " << captureWindowStartTime_ << " or " - << (act.timestamp() + act.duration()) << " > " << captureWindowEndTime_; - } - return out_of_range; -} - -inline static bool isBlockListedRuntimeCbid(CUpti_CallbackId cbid) { - // Some CUDA calls that are very frequent and also not very interesting. - // Filter these out to reduce trace size. - if (cbid == CUPTI_RUNTIME_TRACE_CBID_cudaGetDevice_v3020 || - cbid == CUPTI_RUNTIME_TRACE_CBID_cudaSetDevice_v3020 || - cbid == CUPTI_RUNTIME_TRACE_CBID_cudaGetLastError_v3020 || - // Don't care about cudaEvents - cbid == CUPTI_RUNTIME_TRACE_CBID_cudaEventCreate_v3020 || - cbid == CUPTI_RUNTIME_TRACE_CBID_cudaEventCreateWithFlags_v3020 || - cbid == CUPTI_RUNTIME_TRACE_CBID_cudaEventRecord_v3020 || - cbid == CUPTI_RUNTIME_TRACE_CBID_cudaEventDestroy_v3020 || - cbid == CUPTI_RUNTIME_TRACE_CBID_cudaEventSynchronize_v3020) { - return true; - } - - return false; -} - -void CuptiActivityProfiler::handleRuntimeActivity( - const CUpti_ActivityAPI* activity, - ActivityLogger* logger) { - if (isBlockListedRuntimeCbid(activity->cbid)) { - return; - } - VLOG(2) << activity->correlationId - << ": CUPTI_ACTIVITY_KIND_RUNTIME, cbid=" << activity->cbid - << " tid=" << activity->threadId; - int32_t tid = activity->threadId; - const auto& it = resourceInfo_.find({processId(), tid}); - if (it != resourceInfo_.end()) { - tid = it->second.id; - } - const ITraceActivity* linked = linkedActivity( - activity->correlationId, cpuCorrelationMap_); - const auto& runtime_activity = - traceBuffers_->addActivityWrapper(RuntimeActivity(activity, linked, tid)); - checkTimestampOrder(&runtime_activity); - if (outOfRange(runtime_activity)) { - return; - } - runtime_activity.log(*logger); -} - -void CuptiActivityProfiler::handleOverheadActivity( - const CUpti_ActivityOverhead* activity, - ActivityLogger* logger) { - VLOG(2) << ": CUPTI_ACTIVITY_KIND_OVERHEAD" << " overheadKind=" << activity->overheadKind; - - const auto& overhead_activity = - traceBuffers_->addActivityWrapper(OverheadActivity(activity, nullptr)); - overhead_activity.log(*logger); -} - - -inline void CuptiActivityProfiler::updateGpuNetSpan( - const ITraceActivity& gpuOp) { - if (!gpuOp.linkedActivity()) { - VLOG(0) << "Missing linked activity"; - return; - } - const auto& it = clientActivityTraceMap_.find( - gpuOp.linkedActivity()->correlationId()); - if (it == clientActivityTraceMap_.end()) { - // No correlation id mapping? - return; - } - TraceSpan& gpu_span = it->second->second; - if (gpuOp.timestamp() < gpu_span.startTime || gpu_span.startTime == 0) { - gpu_span.startTime = gpuOp.timestamp(); - } - if ((gpuOp.timestamp() + gpuOp.duration()) > gpu_span.endTime) { - gpu_span.endTime = gpuOp.timestamp() + gpuOp.duration(); - } -} - -// I've observed occasional broken timestamps attached to GPU events... -void CuptiActivityProfiler::checkTimestampOrder(const ITraceActivity* act1) { - // Correlated GPU runtime activity cannot - // have timestamp greater than the GPU activity's - const auto& it = correlatedCudaActivities_.find(act1->correlationId()); - if (it == correlatedCudaActivities_.end()) { - correlatedCudaActivities_.insert({act1->correlationId(), act1}); - return; - } - - // Activities may be appear in the buffers out of order. - // If we have a runtime activity in the map, it should mean that we - // have a GPU activity passed in, and vice versa. - const ITraceActivity* act2 = it->second; - if (act2->type() == ActivityType::CUDA_RUNTIME) { - // Buffer is out-of-order. - // Swap so that runtime activity is first for the comparison below. - std::swap(act1, act2); - } - if (act1->timestamp() > act2->timestamp()) { - LOG(WARNING) << "GPU op timestamp (" << act2->timestamp() - << ") < runtime timestamp (" << act1->timestamp() << ") by " - << act1->timestamp() - act2->timestamp() << "us"; - LOG(WARNING) << "Name: " << act2->name() - << " Device: " << act2->deviceId() - << " Stream: " << act2->resourceId(); - } -} - -inline void CuptiActivityProfiler::handleGpuActivity( - const ITraceActivity& act, - ActivityLogger* logger) { - if (outOfRange(act)) { - return; - } - checkTimestampOrder(&act); - VLOG(2) << act.correlationId() << ": " - << act.name(); - recordStream(act.deviceId(), act.resourceId(), ""); - act.log(*logger); - updateGpuNetSpan(act); - if (config_->selectedActivityTypes().count(ActivityType::GPU_USER_ANNOTATION)) { - const auto& it = userCorrelationMap_.find(act.correlationId()); - if (it != userCorrelationMap_.end()) { - const auto& it2 = activityMap_.find(it->second); - if (it2 != activityMap_.end()) { - recordStream(act.deviceId(), act.resourceId(), "context"); - gpuUserEventMap_.insertOrExtendEvent(*it2->second, act); - } - } - } -} - -const ITraceActivity* CuptiActivityProfiler::linkedActivity( - int32_t correlationId, - const std::unordered_map& correlationMap) { - const auto& it = correlationMap.find(correlationId); - if (it != correlationMap.end()) { - const auto& it2 = activityMap_.find(it->second); - if (it2 != activityMap_.end()) { - return it2->second; - } - } - return nullptr; -} - -template -inline void CuptiActivityProfiler::handleGpuActivity( - const T* act, ActivityLogger* logger) { - const ITraceActivity* linked = linkedActivity( - act->correlationId, cpuCorrelationMap_); - const auto& gpu_activity = - traceBuffers_->addActivityWrapper(GpuActivity(act, linked)); - handleGpuActivity(gpu_activity, logger); -} - -void CuptiActivityProfiler::handleCuptiActivity(const CUpti_Activity* record, ActivityLogger* logger) { - switch (record->kind) { - case CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION: - handleCorrelationActivity( - reinterpret_cast( - record)); - break; - case CUPTI_ACTIVITY_KIND_RUNTIME: - handleRuntimeActivity( - reinterpret_cast(record), logger); - break; - case CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL: - handleGpuActivity( - reinterpret_cast(record), logger); - break; - case CUPTI_ACTIVITY_KIND_MEMCPY: - handleGpuActivity( - reinterpret_cast(record), logger); - break; - case CUPTI_ACTIVITY_KIND_MEMCPY2: - handleGpuActivity( - reinterpret_cast(record), logger); - break; - case CUPTI_ACTIVITY_KIND_MEMSET: - handleGpuActivity( - reinterpret_cast(record), logger); - break; - case CUPTI_ACTIVITY_KIND_OVERHEAD: - handleOverheadActivity (reinterpret_cast(record), logger); - break; - default: - LOG(WARNING) << "Unexpected activity type: " << record->kind; - break; - } -} -#endif // HAS_CUPTI - -void CuptiActivityProfiler::configureChildProfilers() { - // If child profilers are enabled create profiler sessions - for (auto& profiler: profilers_) { - int64_t start_time_ms = duration_cast( - profileStartTime_.time_since_epoch()).count(); - LOG(INFO) << "Running child profiler " << profiler->name() << " for " - << config_->activitiesDuration().count() << " ms"; - auto session = profiler->configure( - start_time_ms, - config_->activitiesDuration().count(), - config_->selectedActivityTypes(), - *config_ - ); - if (session) { - sessions_.push_back(std::move(session)); - } - } -} - -void CuptiActivityProfiler::configure( - const Config& config, - const time_point& now) { - std::lock_guard guard(mutex_); - if (isActive()) { - LOG(ERROR) << "CuptiActivityProfiler already busy, terminating"; - return; - } - - config_ = config.clone(); - - if (config_->activitiesDuration().count() == 0) { - // Use default if not specified - config_->setActivitiesDuration( - config_->activitiesDurationDefault()); - } - - // Ensure we're starting in a clean state - resetTraceData(); - -#if !USE_GOOGLE_LOG - // Add a LoggerObserverCollector to collect all logs during the trace. - loggerCollectorMetadata_ = std::make_unique(); - Logger::addLoggerObserver(loggerCollectorMetadata_.get()); -#endif // !USE_GOOGLE_LOG - - profileStartTime_ = config_->requestTimestamp(); - - if (config_->hasProfileStartIteration()) { - profileStartIter_ = config_->profileStartIteration(); - profileEndIter_ = profileStartIter_ + config_->activitiesRunIterations(); - } else { - - profileStartIter_ = -1; - profileEndIter_ = (std::numeric_limits::max)(); - - if (profileStartTime_ < now) { - LOG(ERROR) << "Not starting tracing - start timestamp is in the past. Time difference (ms): " << duration_cast(now - profileStartTime_).count(); - return; - } else if ((profileStartTime_ - now) < config_->activitiesWarmupDuration()) { - LOG(ERROR) << "Not starting tracing - insufficient time for warmup. Time to warmup (ms): " << duration_cast(profileStartTime_ - now).count() ; - return; - } - } - - if (LOG_IS_ON(INFO)) { - config_->printActivityProfilerConfig(LIBKINETO_DBG_STREAM); - } - if (!cpuOnly_ && !libkineto::api().client()) { - if (profileStartIter_ < 0) { - LOG(INFO) << "GPU-only tracing for " - << config_->activitiesDuration().count() << "ms"; - } else { - LOG(INFO) << "GPU-only tracing for " - << config_->activitiesRunIterations() << " iterations"; - } - } - - // Set useful metadata into the logger. - LOGGER_OBSERVER_SET_TRACE_DURATION_MS(config_->activitiesDuration().count()); - if (!config_->requestTraceID().empty()) { - LOGGER_OBSERVER_SET_TRACE_ID(config_->requestTraceID()); - } - if (!config_->requestGroupTraceID().empty()) { - LOGGER_OBSERVER_SET_GROUP_TRACE_ID(config_->requestGroupTraceID()); - } - LOGGER_OBSERVER_ADD_DESTINATION(config_->activitiesLogUrl()); - -#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER) - if (!cpuOnly_) { - // Enabling CUPTI activity tracing incurs a larger perf hit at first, - // presumably because structures are allocated and initialized, callbacks - // are activated etc. After a while the overhead decreases and stabilizes. - // It's therefore useful to perform some warmup before starting recording. - LOG(INFO) << "Enabling GPU tracing"; - cupti_.setMaxBufferSize(config_->activitiesMaxGpuBufferSize()); - - time_point timestamp; - if (VLOG_IS_ON(1)) { - timestamp = system_clock::now(); - } -#ifdef HAS_CUPTI - cupti_.enableCuptiActivities(config_->selectedActivityTypes()); -#else - cupti_.enableActivities(config_->selectedActivityTypes()); -#endif - if (VLOG_IS_ON(1)) { - auto t2 = system_clock::now(); - addOverheadSample( - setupOverhead_, duration_cast(t2 - timestamp).count()); - } - } -#endif // HAS_CUPTI || HAS_ROCTRACER - - if (profilers_.size() > 0) { - configureChildProfilers(); - } - - if (libkineto::api().client()) { - libkineto::api().client()->warmup(config_->isOpInputsCollectionEnabled()); - } - if (profileStartIter_ >= 0) { - LOG(INFO) << "Tracing starting on iteration = " << profileStartIter_; - } else { - LOG(INFO) << "Tracing starting in " - << duration_cast(profileStartTime_ - now).count() << "s"; - } - - traceBuffers_ = std::make_unique(); - captureWindowStartTime_ = captureWindowEndTime_ = 0; - currentRunloopState_ = RunloopState::Warmup; -} - -void CuptiActivityProfiler::startTraceInternal(const time_point& now) { - captureWindowStartTime_ = libkineto::timeSinceEpoch(now); - VLOG(0) << "Warmup -> CollectTrace"; - for (auto& session: sessions_){ - LOG(INFO) << "Starting child profiler session"; - session->start(); - } - currentRunloopState_ = RunloopState::CollectTrace; -} - -void CuptiActivityProfiler::stopTraceInternal(const time_point& now) { - if (captureWindowEndTime_ == 0) { - captureWindowEndTime_ = libkineto::timeSinceEpoch(now); - } -#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER) - if (!cpuOnly_) { - time_point timestamp; - if (VLOG_IS_ON(1)) { - timestamp = system_clock::now(); - } -#ifdef HAS_CUPTI - cupti_.disableCuptiActivities(config_->selectedActivityTypes()); -#else - cupti_.disableActivities(config_->selectedActivityTypes()); -#endif - if (VLOG_IS_ON(1)) { - auto t2 = system_clock::now(); - addOverheadSample( - setupOverhead_, duration_cast(t2 - timestamp).count()); - } - } -#endif // HAS_CUPTI || HAS_ROCTRACER - - if (currentRunloopState_ == RunloopState::CollectTrace) { - VLOG(0) << "CollectTrace -> ProcessTrace"; - } else { - LOG(WARNING) << "Called stopTrace with state == " << - static_cast::type>( - currentRunloopState_.load()); - } - for (auto& session: sessions_){ - LOG(INFO) << "Stopping child profiler session"; - session->stop(); - } - currentRunloopState_ = RunloopState::ProcessTrace; -} - -void CuptiActivityProfiler::resetInternal() { - resetTraceData(); - currentRunloopState_ = RunloopState::WaitForRequest; -} - -bool CuptiActivityProfiler::isWarmupDone( - const time_point& now, - int64_t currentIter) const { - // is it a time based config - if (profileStartIter_ < 0) { - // qualify that this check is not being called from application step() API - // this avoids races between the step() API and periodically invoked - // profiler run loop step() method - return (currentIter < 0) && (now >= profileStartTime_); - } - // this is an iteration based config - if (currentIter < 0) { - return false; - } - return currentIter >= profileStartIter_; -} - -bool CuptiActivityProfiler::isCollectionDone( - const time_point& now, - int64_t currentIter) const { - // is it a time based config - if (profileStartIter_ < 0) { - // qualify that this check is not being called from application step() API - return (currentIter < 0) && (now >= profileEndTime_); - } - // this is an iteration based config - if (currentIter < 0) { - return false; - } - return currentIter >= profileEndIter_; -} - -const time_point CuptiActivityProfiler::performRunLoopStep( - const time_point& now, - const time_point& nextWakeupTime, - int64_t currentIter) { - auto new_wakeup_time = nextWakeupTime; - bool warmup_done = false, collection_done = false; - - VLOG_IF(1, currentIter >= 0) << "Run loop on application step(), iteration = " - << currentIter; - - switch (currentRunloopState_) { - case RunloopState::WaitForRequest: - VLOG(1) << "State: WaitForRequest"; - // Nothing to do - break; - - case RunloopState::Warmup: - VLOG(1) << "State: Warmup"; - warmup_done = isWarmupDone(now, currentIter); -#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER) - // Flushing can take a while so avoid doing it close to the start time - if (!cpuOnly_ && currentIter < 0 && - (profileStartIter_ >= 0 || nextWakeupTime < profileStartTime_)) { - cupti_.clearActivities(); - } - - if (cupti_.stopCollection) { - // Go to process trace to clear any outstanding buffers etc - LOG(WARNING) << "Trace terminated during warmup"; - std::lock_guard guard(mutex_); - stopTraceInternal(now); - resetInternal(); - VLOG(0) << "Warmup -> WaitForRequest"; - break; - } -#endif // HAS_CUPTI || HAS_ROCTRACER - - if (warmup_done) { - UST_LOGGER_MARK_COMPLETED(kWarmUpStage); - if (profileStartIter_ < 0 && - (now > profileStartTime_ + milliseconds(10))) { - LOG(WARNING) - << "Tracing started " - << duration_cast(now - profileStartTime_).count() - << "ms late!"; - } else { - LOG(INFO) << "Tracing started"; - } - startTrace(now); - if (libkineto::api().client()) { - libkineto::api().client()->start(); - } - if (nextWakeupTime > profileEndTime_) { - new_wakeup_time = profileEndTime_; - } - } else if (nextWakeupTime > profileStartTime_) { - new_wakeup_time = profileStartTime_; - } - - break; - - case RunloopState::CollectTrace: - VLOG(1) << "State: CollectTrace"; - // captureWindowStartTime_ can be set by external threads, - // so recompute end time. - // FIXME: Is this a good idea for synced start? - if (profileStartIter_ < 0) { - std::lock_guard guard(mutex_); - profileEndTime_ = time_point( - microseconds(captureWindowStartTime_)) + - config_->activitiesDuration(); - } - - collection_done = isCollectionDone(now, currentIter); - - // TODO revisit stopCollection_ is not used right now - if (collection_done || stopCollection_.exchange(false) -#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER) - || cupti_.stopCollection -#endif // HAS_CUPTI || HAS_ROCTRACER - ){ - // Update runloop state first to prevent further updates to shared state - LOG(INFO) << "Tracing complete."; - if (currentIter > 0) { - LOG(INFO) << "This state change was invoked by application's step() call"; - } - // FIXME: Need to communicate reason for stopping on errors - if (libkineto::api().client()) { - libkineto::api().client()->stop(); - } - std::lock_guard guard(mutex_); - stopTraceInternal(now); - VLOG_IF(0, collection_done) << "Reached profile end time"; - - UST_LOGGER_MARK_COMPLETED(kCollectionStage); - } else if (profileStartIter_ >= 0) { - // nothing to do here - } else if (now < profileEndTime_ && profileEndTime_ < nextWakeupTime) { - new_wakeup_time = profileEndTime_; - } - - break; - - case RunloopState::ProcessTrace: - VLOG(1) << "State: ProcessTrace"; - // skip this state transition if it called from the step() api - // of the profiler. - // else it could lead to a race between the profiler thread and an - // application thread calling step() - if (currentIter >= 0) { - return new_wakeup_time; - } - // FIXME: Probably want to allow interruption here - // for quickly handling trace request via synchronous API - std::lock_guard guard(mutex_); - processTraceInternal(*logger_); - UST_LOGGER_MARK_COMPLETED(kPostProcessingStage); - resetInternal(); - VLOG(0) << "ProcessTrace -> WaitForRequest"; - break; - } - - return new_wakeup_time; -} - -void CuptiActivityProfiler::finalizeTrace(const Config& config, ActivityLogger& logger) { - LOG(INFO) << "Recorded nets:"; - { - for (const auto& it : iterationCountMap_) { - LOG(INFO) << it.first << ": " << it.second << " iterations"; - } - iterationCountMap_.clear(); - } - - // Process names - int32_t pid = processId(); - string process_name = processName(pid); - if (!process_name.empty()) { - logger.handleDeviceInfo( - {pid, process_name, "CPU"}, captureWindowStartTime_); - if (!cpuOnly_) { - // GPU events use device id as pid (0-7). - constexpr int kMaxGpuCount = 8; - for (int gpu = 0; gpu < kMaxGpuCount; gpu++) { - logger.handleDeviceInfo( - {gpu, process_name, fmt::format("GPU {}", gpu)}, - captureWindowStartTime_); - } - } - } - - // Thread & stream info - for (auto pair : resourceInfo_) { - const auto& resource = pair.second; - logger.handleResourceInfo(resource, captureWindowStartTime_); - } - - for (const auto& iterations : traceSpans_) { - for (const auto& span_pair : iterations.second) { - const TraceSpan& gpu_span = span_pair.second; - if (gpu_span.opCount > 0) { - logger.handleTraceSpan(gpu_span); - } - } - } - - // Overhead info - overheadInfo_.push_back(ActivityLogger::OverheadInfo("CUPTI Overhead")); - for(const auto& info : overheadInfo_) { - logger.handleOverheadInfo(info, captureWindowStartTime_); - } - - gpuUserEventMap_.logEvents(&logger); - -#if !USE_GOOGLE_LOG - // Save logs from LoggerCollector objects into Trace metadata. - auto LoggerMD = loggerCollectorMetadata_->extractCollectorMetadata(); - std::unordered_map> LoggerMDString; - for (auto& md : LoggerMD) { - LoggerMDString[toString(md.first)] = md.second; - } -#endif // !USE_GOOGLE_LOG - - logger.finalizeTrace(config, std::move(traceBuffers_), captureWindowEndTime_, LoggerMDString); -} - -void CuptiActivityProfiler::resetTraceData() { -#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER) - if (!cpuOnly_) { - cupti_.clearActivities(); - } -#endif // HAS_CUPTI || HAS_ROCTRACER - activityMap_.clear(); - cpuCorrelationMap_.clear(); - correlatedCudaActivities_.clear(); - gpuUserEventMap_.clear(); - traceSpans_.clear(); - clientActivityTraceMap_.clear(); - traceBuffers_ = nullptr; - metadata_.clear(); - sessions_.clear(); -#if !USE_GOOGLE_LOG - Logger::removeLoggerObserver(loggerCollectorMetadata_.get()); -#endif // !USE_GOOGLE_LOG -} - - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.h b/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.h deleted file mode 100644 index 208833a4db720429982a63ed72ffa4762ef00bd0..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.h +++ /dev/null @@ -1,364 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "ThreadUtil.h" -#include "TraceSpan.h" -#include "libkineto.h" -#include "output_base.h" -#include "GenericTraceActivity.h" -#include "IActivityProfiler.h" -#include "LoggerCollector.h" - -namespace KINETO_NAMESPACE { - -class Config; -class CuptiActivityApi; -class RoctracerActivityApi; - -class CuptiActivityProfiler { - public: - CuptiActivityProfiler(CuptiActivityApi& cupti, bool cpuOnly); - CuptiActivityProfiler(RoctracerActivityApi& rai, bool cpuOnly); - CuptiActivityProfiler(const CuptiActivityProfiler&) = delete; - CuptiActivityProfiler& operator=(const CuptiActivityProfiler&) = delete; - - bool isActive() const { - return currentRunloopState_ != RunloopState::WaitForRequest; - } - - // Invoke at a regular interval to perform profiling activities. - // When not active, an interval of 1-5 seconds is probably fine, - // depending on required warm-up time and delayed start time. - // When active, it's a good idea to invoke more frequently to stay below - // memory usage limit (ACTIVITIES_MAX_GPU_BUFFER_SIZE_MB) during warmup. - const std::chrono::time_point performRunLoopStep( - const std::chrono::time_point& now, - const std::chrono::time_point& nextWakeupTime, - int64_t currentIter = -1); - - // Used for async requests - void setLogger(ActivityLogger* logger) { - logger_ = logger; - } - - // Synchronous control API - void startTrace( - const std::chrono::time_point& now) { - std::lock_guard guard(mutex_); - startTraceInternal(now); - } - - void stopTrace(const std::chrono::time_point& now) { - std::lock_guard guard(mutex_); - stopTraceInternal(now); - } - - // Process CPU and GPU traces - void processTrace(ActivityLogger& logger) { - std::lock_guard guard(mutex_); - processTraceInternal(logger); - } - - void reset() { - std::lock_guard guard(mutex_); - resetInternal(); - } - - // Set up profiler as specified in config. - void configure( - const Config& config, - const std::chrono::time_point& now); - - // Registered with client API to pass CPU trace events over - void transferCpuTrace( - std::unique_ptr cpuTrace); - - Config& config() { - return *config_; - } - - inline void recordThreadInfo() { - int32_t sysTid = systemThreadId(); - // Note we're using the lower 32 bits of the (opaque) pthread id - // as key, because that's what CUPTI records. - int32_t tid = threadId(); - int32_t pid = processId(); - std::lock_guard guard(mutex_); - recordThreadInfo(sysTid, tid, pid); - } - - // T107508020: We can deprecate the recordThreadInfo(void) once we optimized profiler_kineto - void recordThreadInfo(int32_t sysTid, int32_t tid, int32_t pid) { - if (resourceInfo_.find({pid, tid}) == resourceInfo_.end()) { - resourceInfo_.emplace( - std::make_pair(pid, tid), - ActivityLogger::ResourceInfo( - pid, - sysTid, - sysTid, // sortindex - fmt::format("thread {} ({})", sysTid, getThreadName()))); - } - } - - void addMetadata(const std::string& key, const std::string& value) { - std::lock_guard guard(mutex_); - metadata_[key] = value; - } - - void addChildActivityProfiler( - std::unique_ptr profiler) { - std::lock_guard guard(mutex_); - profilers_.push_back(std::move(profiler)); - } - - protected: - - using CpuGpuSpanPair = std::pair; - static const CpuGpuSpanPair& defaultTraceSpan(); - - private: - - // Map of gpu activities to user defined events - class GpuUserEventMap { - public: - // Insert a user defined event which maps to the gpu trace activity. - // If the user defined event mapping already exists this will update the - // gpu side span to include the span of gpuTraceActivity. - void insertOrExtendEvent(const ITraceActivity& cpuTraceActivity, - const ITraceActivity& gpuTraceActivity); - // Log out the events to the logger - void logEvents(ActivityLogger *logger); - - void clear() { - streamSpanMap_.clear(); - } - - private: - // device id and stream name - using StreamKey = std::pair; - - // map of correlation id to TraceSpan - using CorrelationSpanMap = - std::unordered_map; - std::map streamSpanMap_; - }; - - GpuUserEventMap gpuUserEventMap_; - // id -> activity* - std::unordered_map activityMap_; - // cuda runtime id -> pytorch op id - // CUPTI provides a mechanism for correlating Cuda events to arbitrary - // external events, e.g.operator activities from PyTorch. - std::unordered_map cpuCorrelationMap_; - // CUDA runtime <-> GPU Activity - std::unordered_map - correlatedCudaActivities_; - std::unordered_map userCorrelationMap_; - - // data structure to collect cuptiActivityFlushAll() latency overhead - struct profilerOverhead { - int64_t overhead; - int cntr; - }; - - bool isWarmupDone( - const std::chrono::time_point& now, - int64_t currentIter) const; - - bool isCollectionDone( - const std::chrono::time_point& now, - int64_t currentIter) const; - - void startTraceInternal( - const std::chrono::time_point& now); - - void stopTraceInternal( - const std::chrono::time_point& now); - - void processTraceInternal(ActivityLogger& logger); - - void resetInternal(); - - void finalizeTrace(const Config& config, ActivityLogger& logger); - - void configureChildProfilers(); - - // Process a single CPU trace - void processCpuTrace( - libkineto::CpuTraceBuffer& cpuTrace, - ActivityLogger& logger); - - // Create resource names for streams - inline void recordStream(int device, int id, const char* postfix) { - if (resourceInfo_.find({device, id}) == resourceInfo_.end()) { - resourceInfo_.emplace( - std::make_pair(device, id), - ActivityLogger::ResourceInfo( - device, id, id, fmt::format( - "stream {} {}", id, postfix))); - } - } - - // Record client trace span for subsequent lookups from activities - // Also creates a corresponding GPU-side span. - CpuGpuSpanPair& recordTraceSpan(TraceSpan& span, int gpuOpCount); - - // Returns true if net name is to be tracked for a specified number of - // iterations. - bool iterationTargetMatch(libkineto::CpuTraceBuffer& trace); - - // net name to id - int netId(const std::string& netName); - - const ITraceActivity* linkedActivity( - int32_t correlationId, - const std::unordered_map& correlationMap); - -#ifdef HAS_CUPTI - // Process generic CUPTI activity - void handleCuptiActivity(const CUpti_Activity* record, ActivityLogger* logger); - - // Process specific GPU activity types - void updateGpuNetSpan(const ITraceActivity& gpuOp); - bool outOfRange(const ITraceActivity& act); - void handleCorrelationActivity( - const CUpti_ActivityExternalCorrelation* correlation); - void handleRuntimeActivity( - const CUpti_ActivityAPI* activity, ActivityLogger* logger); - void handleOverheadActivity( - const CUpti_ActivityOverhead* activity, ActivityLogger* logger); - void handleGpuActivity(const ITraceActivity& act, - ActivityLogger* logger); - template - void handleGpuActivity(const T* act, ActivityLogger* logger); -#endif // HAS_CUPTI - - void resetTraceData(); - - void addOverheadSample(profilerOverhead& counter, int64_t overhead) { - counter.overhead += overhead; - counter.cntr++; - } - int64_t getOverhead(const profilerOverhead& counter) { - if (counter.cntr == 0) { - return 0; - } - return counter.overhead / counter.cntr; - } - - void checkTimestampOrder(const ITraceActivity* act1); - - // On-demand request configuration - std::unique_ptr config_; - - // Logger used during trace processing - ActivityLogger* logger_; - - // Calls to CUPTI is encapsulated behind this interface -#ifdef HAS_ROCTRACER - RoctracerActivityApi& cupti_; // Design failure here -#else - CuptiActivityApi& cupti_; -#endif - - enum class RunloopState { - WaitForRequest, - Warmup, - CollectTrace, - ProcessTrace - }; - - // Start and end time used for triggering and stopping profiling - std::chrono::time_point profileStartTime_; - std::chrono::time_point profileEndTime_; - int64_t profileStartIter_ = -1, profileEndIter_ = -1; - - - // All recorded trace spans, both CPU and GPU - // Trace Id -> list of iterations. - // Using map of lists for the iterator semantics, since we are recording - // pointers to the elements in this structure. - std::map> traceSpans_; - - // Maintain a map of client trace activity to trace span. - // Maps correlation id -> TraceSpan* held by traceSpans_. - using ActivityTraceMap = std::unordered_map; - ActivityTraceMap clientActivityTraceMap_; - - // Cache thread names and system thread ids for pthread ids, - // and stream ids for GPU streams - std::map< - std::pair, - ActivityLogger::ResourceInfo> resourceInfo_; - - std::vector overheadInfo_; - - // the overhead to flush the activity buffer - profilerOverhead flushOverhead_; - // the overhead to enable/disable activity tracking - profilerOverhead setupOverhead_; - - bool cpuOnly_{false}; - - // *************************************************************************** - // Below state is shared with external threads. - // These need to either be atomic, accessed under lock or only used - // by external threads in separate runloop phases from the profiler thread. - // *************************************************************************** - - // Mutex to protect non-atomic access to below state - std::mutex mutex_; - - // Runloop phase - std::atomic currentRunloopState_{RunloopState::WaitForRequest}; - - // Keep track of the start time of the first net in the current trace. - // This is only relevant to Caffe2 as PyTorch does not have nets. - // All CUDA events before this time will be removed - // Can be written by external threads during collection. - int64_t captureWindowStartTime_{0}; - // Similarly, all CUDA API events after the last net event will be removed - int64_t captureWindowEndTime_{0}; - - // span name -> iteration count - std::map iterationCountMap_; - // Flag used to stop tracing from external api callback. - // Needs to be atomic since it's set from a different thread. - std::atomic_bool stopCollection_{false}; - - // Buffers where trace data is stored - std::unique_ptr traceBuffers_; - - // Trace metadata - std::unordered_map metadata_; - - // child activity profilers - std::vector> profilers_; - - // a vector of active profiler plugin sessions - std::vector> sessions_; - - // LoggerCollector to collect all LOGs during the trace -#if !USE_GOOGLE_LOG - std::unique_ptr loggerCollectorMetadata_; -#endif // !USE_GOOGLE_LOG -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.cpp deleted file mode 100644 index 1876003998dc0c66f882d939ca8100750cfd046a..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.cpp +++ /dev/null @@ -1,260 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "CuptiCallbackApi.h" - -#include -#include -#include -#include -#include - -#ifdef HAS_CUPTI -#include "cupti_call.h" -#endif -#include "Logger.h" - - -namespace KINETO_NAMESPACE { - -// limit on number of handles per callback type -constexpr size_t MAX_CB_FNS_PER_CB = 8; - -// Reader Writer lock types -using ReaderWriterLock = std::shared_timed_mutex; -using ReaderLockGuard = std::shared_lock; -using WriteLockGuard = std::unique_lock; - -static ReaderWriterLock callbackLock_; - -/* Callback Table : - * Overall goal of the design is to optimize the lookup of function - * pointers. The table is structured at two levels and the leaf - * elements in the table are std::list to enable fast access/inserts/deletes - * - * | - * -> cb id 0 -> std::list of callbacks - * ... - * -> cb id n -> std::list of callbacks - * | - * ... - * CallbackTable is the finaly table type above - * See type declrartions in header file. - */ - - -/* callback_switchboard : is the global callback handler we register - * with CUPTI. The goal is to make it as efficient as possible - * to re-direct to the registered callback(s). - * - * Few things to care about : - * a) use if/then switches rather than map/hash structures - * b) avoid dynamic memory allocations - * c) be aware of locking overheads - */ -#ifdef HAS_CUPTI -static void CUPTIAPI callback_switchboard( -#else -static void callback_switchboard( -#endif - void* /* unused */, - CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - const CUpti_CallbackData* cbInfo) { - - // below statement is likey going to call a mutex - // on the singleton access - CuptiCallbackApi::singleton().__callback_switchboard( - domain, cbid, cbInfo); -} - - -void CuptiCallbackApi::__callback_switchboard( - CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - const CUpti_CallbackData* cbInfo) { - VLOG(0) << "Callback: domain = " << domain << ", cbid = " << cbid; - CallbackList *cblist = nullptr; - - switch (domain) { - - // add the fastest path for kernel launch callbacks - // as these are the most frequent ones - case CUPTI_CB_DOMAIN_RUNTIME_API: - switch (cbid) { - case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000: - cblist = &callbacks_.runtime[ - CUDA_LAUNCH_KERNEL - __RUNTIME_CB_DOMAIN_START]; - break; - default: - break; - } - break; - - case CUPTI_CB_DOMAIN_RESOURCE: - switch (cbid) { - case CUPTI_CBID_RESOURCE_CONTEXT_CREATED: - cblist = &callbacks_.resource[ - RESOURCE_CONTEXT_CREATED - __RESOURCE_CB_DOMAIN_START]; - break; - case CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING: - cblist = &callbacks_.resource[ - RESOURCE_CONTEXT_DESTROYED - __RESOURCE_CB_DOMAIN_START]; - break; - default: - break; - } - break; - - default: - return; - } - - // ignore callbacks that are not handled - if (cblist == nullptr) { - return; - } - - // make a copy of the callback list so we avoid holding lock - // in common case this should be just one func pointer copy - std::array callbacks; - int num_cbs = 0; - { - ReaderLockGuard rl(callbackLock_); - int i = 0; - for (auto it = cblist->begin(); - it != cblist->end() && i < MAX_CB_FNS_PER_CB; - it++, i++) { - callbacks[i] = *it; - } - num_cbs = i; - } - - for (int i = 0; i < num_cbs; i++) { - auto fn = callbacks[i]; - fn(domain, cbid, cbInfo); - } -} - -CuptiCallbackApi& CuptiCallbackApi::singleton() { - static CuptiCallbackApi instance; - return instance; -} - -CuptiCallbackApi::CuptiCallbackApi() { -#ifdef HAS_CUPTI - lastCuptiStatus_ = CUPTI_ERROR_UNKNOWN; - lastCuptiStatus_ = CUPTI_CALL_NOWARN( - cuptiSubscribe(&subscriber_, - (CUpti_CallbackFunc)callback_switchboard, - nullptr)); - - initSuccess_ = (lastCuptiStatus_ == CUPTI_SUCCESS); -#endif -} - -CuptiCallbackApi::CallbackList* CuptiCallbackApi::CallbackTable::lookup( - CUpti_CallbackDomain domain, CuptiCallBackID cbid) { - size_t idx; - - switch (domain) { - - case CUPTI_CB_DOMAIN_RESOURCE: - assert(cbid >= __RESOURCE_CB_DOMAIN_START); - assert(cbid < __RESOURCE_CB_DOMAIN_END); - idx = cbid - __RESOURCE_CB_DOMAIN_START; - return &resource.at(idx); - - case CUPTI_CB_DOMAIN_RUNTIME_API: - assert(cbid >= __RUNTIME_CB_DOMAIN_START); - assert(cbid < __RUNTIME_CB_DOMAIN_END); - idx = cbid - __RUNTIME_CB_DOMAIN_START; - return &runtime.at(idx); - - default: - LOG(WARNING) << " Unsupported callback domain : " << domain; - return nullptr; - } -} - -bool CuptiCallbackApi::registerCallback( - CUpti_CallbackDomain domain, - CuptiCallBackID cbid, - CuptiCallbackFn cbfn) { - CallbackList* cblist = callbacks_.lookup(domain, cbid); - - if (!cblist) { - LOG(WARNING) << "Could not register callback -- domain = " << domain - << " callback id = " << cbid; - return false; - } - - // avoid duplicates - auto it = std::find(cblist->begin(), cblist->end(), cbfn); - if (it != cblist->end()) { - LOG(WARNING) << "Adding duplicate callback -- domain = " << domain - << " callback id = " << cbid; - return true; - } - - if (cblist->size() == MAX_CB_FNS_PER_CB) { - LOG(WARNING) << "Already registered max callback -- domain = " << domain - << " callback id = " << cbid; - } - - WriteLockGuard wl(callbackLock_); - cblist->push_back(cbfn); - return true; -} - -bool CuptiCallbackApi::deleteCallback( - CUpti_CallbackDomain domain, - CuptiCallBackID cbid, - CuptiCallbackFn cbfn) { - CallbackList* cblist = callbacks_.lookup(domain, cbid); - if (!cblist) { - LOG(WARNING) << "Attempting to remove unsupported callback -- domain = " << domain - << " callback id = " << cbid; - return false; - } - - // Locks are not required here as - // https://en.cppreference.com/w/cpp/container/list/erase - // "References and iterators to the erased elements are invalidated. - // Other references and iterators are not affected." - auto it = std::find(cblist->begin(), cblist->end(), cbfn); - if (it == cblist->end()) { - LOG(WARNING) << "Could not find callback to remove -- domain = " << domain - << " callback id = " << cbid; - return false; - } - - WriteLockGuard wl(callbackLock_); - cblist->erase(it); - return true; -} - -bool CuptiCallbackApi::enableCallback( - CUpti_CallbackDomain domain, CUpti_CallbackId cbid) { -#ifdef HAS_CUPTI - if (initSuccess_) { - lastCuptiStatus_ = CUPTI_CALL_NOWARN( - cuptiEnableCallback(1, subscriber_, domain, cbid)); - return (lastCuptiStatus_ == CUPTI_SUCCESS); - } -#endif - return false; -} - -bool CuptiCallbackApi::disableCallback( - CUpti_CallbackDomain domain, CUpti_CallbackId cbid) { -#ifdef HAS_CUPTI - if (initSuccess_) { - lastCuptiStatus_ = CUPTI_CALL_NOWARN( - cuptiEnableCallback(0, subscriber_, domain, cbid)); - return (lastCuptiStatus_ == CUPTI_SUCCESS); - } -#endif - return false; -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.h b/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.h deleted file mode 100644 index 4526f3750b4a134bc888843b8ff347a1f2bf8d5f..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.h +++ /dev/null @@ -1,130 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#ifdef HAS_CUPTI -#include -#endif -#include -#include -#include -#include -#include - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "CuptiCallbackApiMock.h" - -namespace KINETO_NAMESPACE { - -using namespace libkineto; - - -/* CuptiCallbackApi : Provides an abstraction over CUPTI callback - * interface. This enables various callback functions to be registered - * with this class. The class registers a global callback handler that - * redirects to the respective callbacks. - * - * Note: one design choice we made is to only support simple function pointers - * in order to speed up the implementation for fast path. - */ - -using CuptiCallbackFn = void(*)( - CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - const CUpti_CallbackData* cbInfo); - - -class CuptiCallbackApi { - - public: - - /* Global list of supported callback ids - * use the class namespace to avoid confusing with CUPTI enums*/ - enum CuptiCallBackID { - CUDA_LAUNCH_KERNEL = 0, - // can possibly support more callback ids per domain - // - __RUNTIME_CB_DOMAIN_START = CUDA_LAUNCH_KERNEL, - - // Callbacks under Resource CB domain - RESOURCE_CONTEXT_CREATED, - RESOURCE_CONTEXT_DESTROYED, - - __RUNTIME_CB_DOMAIN_END = RESOURCE_CONTEXT_CREATED, - __RESOURCE_CB_DOMAIN_START = RESOURCE_CONTEXT_CREATED, - - __RESOURCE_CB_DOMAIN_END = RESOURCE_CONTEXT_DESTROYED + 1, - }; - - - CuptiCallbackApi(const CuptiCallbackApi&) = delete; - CuptiCallbackApi& operator=(const CuptiCallbackApi&) = delete; - - static CuptiCallbackApi& singleton(); - - bool initSuccess() const { - return initSuccess_; - } - -#ifdef HAS_CUPTI - CUptiResult getCuptiStatus() const { - return lastCuptiStatus_; - } -#endif - - bool registerCallback( - CUpti_CallbackDomain domain, - CuptiCallBackID cbid, - CuptiCallbackFn cbfn); - - // returns false if callback was not found - bool deleteCallback( - CUpti_CallbackDomain domain, - CuptiCallBackID cbid, - CuptiCallbackFn cbfn); - - bool enableCallback(CUpti_CallbackDomain domain, CUpti_CallbackId cbid); - bool disableCallback(CUpti_CallbackDomain domain, CUpti_CallbackId cbid); - - - // Please do not use this method. This has to be exposed as public - // so it is accessible from the callback handler - void __callback_switchboard( - CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - const CUpti_CallbackData* cbInfo); - - private: - - explicit CuptiCallbackApi(); - - // For callback table design overview see the .cpp file - using CallbackList = std::list; - - // level 2 tables sizes are known at compile time - constexpr static size_t RUNTIME_CB_DOMAIN_SIZE - = (__RUNTIME_CB_DOMAIN_END - __RUNTIME_CB_DOMAIN_START); - - constexpr static size_t RESOURCE_CB_DOMAIN_SIZE - = (__RESOURCE_CB_DOMAIN_END - __RESOURCE_CB_DOMAIN_START); - - // level 1 table is a struct - struct CallbackTable { - std::array runtime; - std::array resource; - - CallbackList* lookup(CUpti_CallbackDomain domain, CuptiCallBackID cbid); - }; - - CallbackTable callbacks_; - bool initSuccess_ = false; - -#ifdef HAS_CUPTI - CUptiResult lastCuptiStatus_; - CUpti_SubscriberHandle subscriber_; -#endif -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApiMock.h b/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApiMock.h deleted file mode 100644 index fd51267274f99a0c9949eaac6fdae2dff917c7a0..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApiMock.h +++ /dev/null @@ -1,32 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -// Provides data structures to mock CUPTI Callback API -#ifndef HAS_CUPTI - -enum CUpti_CallbackDomain { - CUPTI_CB_DOMAIN_RESOURCE, - CUPTI_CB_DOMAIN_RUNTIME_API, -}; -enum CUpti_CallbackId { - CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000, - CUPTI_CBID_RESOURCE_CONTEXT_CREATED, - CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING, -}; - -using CUcontext = void*; - -struct CUpti_ResourceData { - CUcontext context; -}; - -constexpr int CUPTI_API_ENTER = 0; -constexpr int CUPTI_API_EXIT = 0; - -struct CUpti_CallbackData { - CUcontext context; - const char* symbolName; - int callbackSite; -}; -#endif // HAS_CUPTI diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.cpp deleted file mode 100644 index 7f1d48c1d00bb7defb6b622c13da55da99312a3b..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.cpp +++ /dev/null @@ -1,112 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "CuptiEventApi.h" - -#include - -#include "Logger.h" -#include "cupti_call.h" - -using namespace std::chrono; -using std::vector; - -namespace KINETO_NAMESPACE { - -CuptiEventApi::CuptiEventApi(CUcontext context) - : context_(context) { - CUPTI_CALL(cuptiGetDeviceId(context_, (uint32_t*)&device_)); -} - -CUpti_EventGroupSets* CuptiEventApi::createGroupSets( - vector& ids) { - CUpti_EventGroupSets* group_sets = nullptr; - CUptiResult res = CUPTI_CALL(cuptiEventGroupSetsCreate( - context_, sizeof(CUpti_EventID) * ids.size(), ids.data(), &group_sets)); - - if (res != CUPTI_SUCCESS || group_sets == nullptr) { - const char* errstr = nullptr; - CUPTI_CALL(cuptiGetResultString(res, &errstr)); - throw std::system_error(EINVAL, std::generic_category(), errstr); - } - - return group_sets; -} - -void CuptiEventApi::destroyGroupSets(CUpti_EventGroupSets* sets) { - CUPTI_CALL(cuptiEventGroupSetsDestroy(sets)); -} - -bool CuptiEventApi::setContinuousMode() { - // Avoid logging noise for CUPTI_ERROR_LEGACY_PROFILER_NOT_SUPPORTED - CUptiResult res = CUPTI_CALL_NOWARN(cuptiSetEventCollectionMode( - context_, CUPTI_EVENT_COLLECTION_MODE_CONTINUOUS)); - if (res == CUPTI_ERROR_LEGACY_PROFILER_NOT_SUPPORTED) { - return false; - } - // Log warning on other errors - CUPTI_CALL(res); - return (res == CUPTI_SUCCESS); -} - -void CuptiEventApi::enablePerInstance(CUpti_EventGroup eventGroup) { - uint32_t profile_all = 1; - CUPTI_CALL(cuptiEventGroupSetAttribute( - eventGroup, - CUPTI_EVENT_GROUP_ATTR_PROFILE_ALL_DOMAIN_INSTANCES, - sizeof(profile_all), - &profile_all)); -} - -uint32_t CuptiEventApi::instanceCount(CUpti_EventGroup eventGroup) { - uint32_t instance_count = 0; - size_t s = sizeof(instance_count); - CUPTI_CALL(cuptiEventGroupGetAttribute( - eventGroup, CUPTI_EVENT_GROUP_ATTR_INSTANCE_COUNT, &s, &instance_count)); - return instance_count; -} - -void CuptiEventApi::enableGroupSet(CUpti_EventGroupSet& set) { - CUptiResult res = CUPTI_CALL_NOWARN(cuptiEventGroupSetEnable(&set)); - if (res != CUPTI_SUCCESS) { - const char* errstr = nullptr; - CUPTI_CALL(cuptiGetResultString(res, &errstr)); - throw std::system_error(EIO, std::generic_category(), errstr); - } -} - -void CuptiEventApi::disableGroupSet(CUpti_EventGroupSet& set) { - CUPTI_CALL(cuptiEventGroupSetDisable(&set)); -} - -void CuptiEventApi::readEvent( - CUpti_EventGroup grp, - CUpti_EventID id, - vector& vals) { - size_t s = sizeof(int64_t) * vals.size(); - CUPTI_CALL(cuptiEventGroupReadEvent( - grp, - CUPTI_EVENT_READ_FLAG_NONE, - id, - &s, - reinterpret_cast(vals.data()))); -} - -vector CuptiEventApi::eventsInGroup(CUpti_EventGroup grp) { - uint32_t group_size = 0; - size_t s = sizeof(group_size); - CUPTI_CALL(cuptiEventGroupGetAttribute( - grp, CUPTI_EVENT_GROUP_ATTR_NUM_EVENTS, &s, &group_size)); - size_t events_size = group_size * sizeof(CUpti_EventID); - vector res(group_size); - CUPTI_CALL(cuptiEventGroupGetAttribute( - grp, CUPTI_EVENT_GROUP_ATTR_EVENTS, &events_size, res.data())); - return res; -} - -CUpti_EventID CuptiEventApi::eventId(const std::string& name) { - CUpti_EventID id{0}; - CUPTI_CALL(cuptiEventGetIdFromName(device_, name.c_str(), &id)); - return id; -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.h b/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.h deleted file mode 100644 index 79610f93f0ecfa62a9508d4caddfa876518169d3..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.h +++ /dev/null @@ -1,49 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include - -namespace KINETO_NAMESPACE { - -// C++ interface to CUPTI Events C API. -// Virtual methods are here mainly to allow easier testing. -class CuptiEventApi { - public: - explicit CuptiEventApi(CUcontext context_); - virtual ~CuptiEventApi() {} - - CUdevice device() { - return device_; - } - - virtual CUpti_EventGroupSets* createGroupSets( - std::vector& ids); - virtual void destroyGroupSets(CUpti_EventGroupSets* sets); - - virtual bool setContinuousMode(); - - virtual void enablePerInstance(CUpti_EventGroup eventGroup); - virtual uint32_t instanceCount(CUpti_EventGroup eventGroup); - - virtual void enableGroupSet(CUpti_EventGroupSet& set); - virtual void disableGroupSet(CUpti_EventGroupSet& set); - - virtual void - readEvent(CUpti_EventGroup g, CUpti_EventID id, std::vector& vals); - virtual std::vector eventsInGroup(CUpti_EventGroup g); - - virtual CUpti_EventID eventId(const std::string& name); - - protected: - // Unit testing - CuptiEventApi() : context_(nullptr), device_(0) {} - - private: - CUcontext context_; - CUdevice device_; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.cpp deleted file mode 100644 index 36401e7434108d1da079aa4ba0264192c5d62838..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.cpp +++ /dev/null @@ -1,107 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "CuptiMetricApi.h" - -#include - -#include "Logger.h" -#include "cupti_call.h" - -using namespace std::chrono; -using std::vector; - -namespace KINETO_NAMESPACE { - -CUpti_MetricID CuptiMetricApi::idFromName(const std::string& name) { - CUpti_MetricID metric_id{~0u}; - CUptiResult res = - CUPTI_CALL(cuptiMetricGetIdFromName(device_, name.c_str(), &metric_id)); - if (res == CUPTI_ERROR_INVALID_METRIC_NAME) { - LOG(WARNING) << "Invalid metric name: " << name; - } - return metric_id; -} - -// Return a map of event IDs and names for a given metric id. -// Note that many events don't have a name. In that case the name will -// be set to the empty string. -std::map CuptiMetricApi::events( - CUpti_MetricID metric_id) { - uint32_t num_events = 0; - CUPTI_CALL(cuptiMetricGetNumEvents(metric_id, &num_events)); - vector ids(num_events); - size_t array_size = num_events * sizeof(CUpti_EventID); - CUPTI_CALL(cuptiMetricEnumEvents(metric_id, &array_size, ids.data())); - std::map res; - for (CUpti_EventID id : ids) { - // Attempt to lookup name from CUPTI - constexpr size_t kMaxEventNameLength = 64; - char cupti_name[kMaxEventNameLength]; - size_t size = kMaxEventNameLength; - CUPTI_CALL( - cuptiEventGetAttribute(id, CUPTI_EVENT_ATTR_NAME, &size, cupti_name)); - cupti_name[kMaxEventNameLength - 1] = 0; - - // CUPTI "helpfully" returns "event_name" when the event is unnamed. - if (size > 0 && strcmp(cupti_name, "event_name") != 0) { - res.emplace(id, cupti_name); - } else { - res.emplace(id, ""); - } - } - return res; -} - -CUpti_MetricValueKind CuptiMetricApi::valueKind(CUpti_MetricID metric) { - CUpti_MetricValueKind res{CUPTI_METRIC_VALUE_KIND_FORCE_INT}; - size_t value_kind_size = sizeof(res); - CUPTI_CALL(cuptiMetricGetAttribute( - metric, CUPTI_METRIC_ATTR_VALUE_KIND, &value_kind_size, &res)); - return res; -} - -CUpti_MetricEvaluationMode CuptiMetricApi::evaluationMode( - CUpti_MetricID metric) { - CUpti_MetricEvaluationMode eval_mode{ - CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE}; - size_t eval_mode_size = sizeof(eval_mode); - CUPTI_CALL(cuptiMetricGetAttribute( - metric, CUPTI_METRIC_ATTR_EVALUATION_MODE, &eval_mode_size, &eval_mode)); - return eval_mode; -} - -// FIXME: Consider caching value kind here -SampleValue CuptiMetricApi::calculate( - CUpti_MetricID metric, - CUpti_MetricValueKind kind, - vector& events, - vector& values, - int64_t duration) { - CUpti_MetricValue metric_value; - CUPTI_CALL(cuptiMetricGetValue( - device_, - metric, - events.size() * sizeof(CUpti_EventID), - events.data(), - values.size() * sizeof(int64_t), - reinterpret_cast(values.data()), - duration, - &metric_value)); - - switch (kind) { - case CUPTI_METRIC_VALUE_KIND_DOUBLE: - case CUPTI_METRIC_VALUE_KIND_PERCENT: - return SampleValue(metric_value.metricValueDouble); - case CUPTI_METRIC_VALUE_KIND_UINT64: - case CUPTI_METRIC_VALUE_KIND_INT64: - case CUPTI_METRIC_VALUE_KIND_THROUGHPUT: - return SampleValue(metric_value.metricValueUint64); - case CUPTI_METRIC_VALUE_KIND_UTILIZATION_LEVEL: - return SampleValue((int)metric_value.metricValueUtilizationLevel); - default: - assert(false); - } - return SampleValue(-1); -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.h b/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.h deleted file mode 100644 index f45d38cd6169dc7fd30208dbb7dac09fd8a9dee5..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.h +++ /dev/null @@ -1,38 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include - -#include -#include - -#include "SampleListener.h" - -namespace KINETO_NAMESPACE { - -// C++ interface to CUPTI Metrics C API. -// Virtual methods are here mainly to allow easier testing. -class CuptiMetricApi { - public: - explicit CuptiMetricApi(CUdevice device) : device_(device) {} - virtual ~CuptiMetricApi() {} - - virtual CUpti_MetricID idFromName(const std::string& name); - virtual std::map events(CUpti_MetricID metric_id); - - virtual CUpti_MetricValueKind valueKind(CUpti_MetricID metric); - virtual CUpti_MetricEvaluationMode evaluationMode(CUpti_MetricID metric); - - virtual SampleValue calculate( - CUpti_MetricID metric, - CUpti_MetricValueKind kind, - std::vector& events, - std::vector& values, - int64_t duration); - - private: - CUdevice device_; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.cpp deleted file mode 100644 index d1b08ab2c13d0615221e71f43f07c3d3fe102a2f..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.cpp +++ /dev/null @@ -1,504 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#ifdef HAS_CUPTI -#include -#if defined(CUDART_VERSION) && CUDART_VERSION > 10000 && CUDART_VERSION < 11040 -#include -#include -#include -#endif // cuda version > 10.00 and < 11.04 -#endif // HAS_CUPTI - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "ScopeExit.h" -#include "CuptiNvPerfMetric.h" -#include "Logger.h" - -namespace KINETO_NAMESPACE { - -// Add a namespace to isolate these utility functions that are only -// going to be used by the CuptiRangeProfiler. These included calls -// to NVIDIA PerfWorks APIs. -namespace nvperf { - - -// Largely based on NVIDIA sample code provided with CUDA release -// files Metric.cpp and Eval.cpp - -// ------------------------------------------------- -// Metric and Counter Data Configuration -// ------------------------------------------------- - - -// Note: Be carful before modifying the code below. There is a specific -// sequence one needs to follow to program the metrics else things may -// stop working. We tried to keep the flow consistent with the example -// code from NVIDIA. Since most of the programmability comes from -// the CUPTI profiler metric names this should be okay. - -// Only supported on CUDA RT Version between 10.0 and 11.04. -// After CUDA RT 11.04, the structure has changed. -// TODO update the structure NVPA_RawMetricsConfig to support 11.04 -#if defined(CUDART_VERSION) && CUDART_VERSION > 10000 && CUDART_VERSION < 11040 - -bool getRawMetricRequests( - NVPA_MetricsContext* metricsContext, - std::vector metricNames, - std::vector& rawMetricsDeps, - std::vector& rawMetricRequests) { - bool isolated = true; - /* Bug in collection with collection of metrics without instances, keep it - * to true*/ - bool keepInstances = true; - - for (const auto& metricName : metricNames) { - - NVPW_MetricsContext_GetMetricProperties_Begin_Params - getMetricPropertiesBeginParams = { - NVPW_MetricsContext_GetMetricProperties_Begin_Params_STRUCT_SIZE, nullptr}; - getMetricPropertiesBeginParams.pMetricsContext = metricsContext; - getMetricPropertiesBeginParams.pMetricName = metricName.c_str(); - - if (!NVPW_CALL( - NVPW_MetricsContext_GetMetricProperties_Begin( - &getMetricPropertiesBeginParams))) { - return false; - } - - for (const char** metricDepsIt = - getMetricPropertiesBeginParams.ppRawMetricDependencies; - *metricDepsIt; - ++metricDepsIt) { - rawMetricsDeps.push_back(*metricDepsIt); - } - - NVPW_MetricsContext_GetMetricProperties_End_Params - getMetricPropertiesEndParams = { - NVPW_MetricsContext_GetMetricProperties_End_Params_STRUCT_SIZE, nullptr}; - getMetricPropertiesEndParams.pMetricsContext = metricsContext; - - if (!NVPW_CALL(NVPW_MetricsContext_GetMetricProperties_End( - &getMetricPropertiesEndParams))) { - return false; - } - } - - for (const auto& rawMetricName : rawMetricsDeps) { - NVPA_RawMetricRequest metricRequest = {NVPA_RAW_METRIC_REQUEST_STRUCT_SIZE, nullptr}; - metricRequest.pMetricName = rawMetricName.c_str(); - metricRequest.isolated = isolated; - metricRequest.keepInstances = keepInstances; - rawMetricRequests.push_back(metricRequest); - VLOG(1) << "Adding raw metric struct : raw metric = " << rawMetricName - << " isolated = " << isolated << " keepinst = " << keepInstances; - } - - if (rawMetricRequests.size() == 0) { - LOG(WARNING) << "CUPTI Profiler was unable to configure any metrics"; - return false; - } - return true; -} - -// Setup CUPTI Profiler Config Image -bool getProfilerConfigImage( - const std::string& chipName, - const std::vector& metricNames, - std::vector& configImage, - const uint8_t* counterAvailabilityImage) { - - NVPW_CUDA_MetricsContext_Create_Params metricsContextCreateParams = { - NVPW_CUDA_MetricsContext_Create_Params_STRUCT_SIZE, nullptr}; - metricsContextCreateParams.pChipName = chipName.c_str(); - - if (!NVPW_CALL( - NVPW_CUDA_MetricsContext_Create(&metricsContextCreateParams))) { - return false; - } - - NVPW_MetricsContext_Destroy_Params metricsContextDestroyParams = { - NVPW_MetricsContext_Destroy_Params_STRUCT_SIZE, nullptr}; - metricsContextDestroyParams.pMetricsContext = - metricsContextCreateParams.pMetricsContext; - - SCOPE_EXIT([&]() { - NVPW_MetricsContext_Destroy( - (NVPW_MetricsContext_Destroy_Params*)&metricsContextDestroyParams); - }); - - // Get all raw metrics required for given metricNames list - std::vector rawMetricRequests; - - // note: we need a variable at this functions scope to hold the string - // pointers for underlying C char arrays. - std::vector rawMetricDeps; - - if (!getRawMetricRequests( - metricsContextCreateParams.pMetricsContext, - metricNames, - rawMetricDeps, - rawMetricRequests)) { - return false; - } - - NVPA_RawMetricsConfigOptions metricsConfigOptions = { - NVPA_RAW_METRICS_CONFIG_OPTIONS_STRUCT_SIZE, nullptr}; - metricsConfigOptions.activityKind = NVPA_ACTIVITY_KIND_PROFILER; - metricsConfigOptions.pChipName = chipName.c_str(); - NVPA_RawMetricsConfig* rawMetricsConfig; - if (!NVPW_CALL( - NVPA_RawMetricsConfig_Create( - &metricsConfigOptions, &rawMetricsConfig))) { - return false; - } - - // TODO check if this is required - if (counterAvailabilityImage) { - NVPW_RawMetricsConfig_SetCounterAvailability_Params - setCounterAvailabilityParams = { - NVPW_RawMetricsConfig_SetCounterAvailability_Params_STRUCT_SIZE, nullptr}; - setCounterAvailabilityParams.pRawMetricsConfig = rawMetricsConfig; - setCounterAvailabilityParams.pCounterAvailabilityImage = - counterAvailabilityImage; - if (!NVPW_CALL( - NVPW_RawMetricsConfig_SetCounterAvailability( - &setCounterAvailabilityParams))) { - return false; - } - } - - NVPW_RawMetricsConfig_Destroy_Params rawMetricsConfigDestroyParams = { - NVPW_RawMetricsConfig_Destroy_Params_STRUCT_SIZE, nullptr}; - rawMetricsConfigDestroyParams.pRawMetricsConfig = rawMetricsConfig; - SCOPE_EXIT([&]() { - NVPW_RawMetricsConfig_Destroy( - (NVPW_RawMetricsConfig_Destroy_Params*)&rawMetricsConfigDestroyParams); - }); - - // Start a Raw Metric Pass group - NVPW_RawMetricsConfig_BeginPassGroup_Params beginPassGroupParams = { - NVPW_RawMetricsConfig_BeginPassGroup_Params_STRUCT_SIZE, nullptr}; - beginPassGroupParams.pRawMetricsConfig = rawMetricsConfig; - if (!NVPW_CALL( - NVPW_RawMetricsConfig_BeginPassGroup(&beginPassGroupParams))) { - return false; - } - - // Add all raw metrics - NVPW_RawMetricsConfig_AddMetrics_Params addMetricsParams = { - NVPW_RawMetricsConfig_AddMetrics_Params_STRUCT_SIZE, nullptr}; - addMetricsParams.pRawMetricsConfig = rawMetricsConfig; - addMetricsParams.pRawMetricRequests = rawMetricRequests.data(); - addMetricsParams.numMetricRequests = rawMetricRequests.size(); - if (!NVPW_CALL( - NVPW_RawMetricsConfig_AddMetrics(&addMetricsParams))) { - return false; - } - - // End pass group - NVPW_RawMetricsConfig_EndPassGroup_Params endPassGroupParams = { - NVPW_RawMetricsConfig_EndPassGroup_Params_STRUCT_SIZE, nullptr}; - endPassGroupParams.pRawMetricsConfig = rawMetricsConfig; - if (!NVPW_CALL( - NVPW_RawMetricsConfig_EndPassGroup(&endPassGroupParams))) { - return false; - } - - // Setup Config Image generation - NVPW_RawMetricsConfig_GenerateConfigImage_Params generateConfigImageParams = { - NVPW_RawMetricsConfig_GenerateConfigImage_Params_STRUCT_SIZE, nullptr}; - generateConfigImageParams.pRawMetricsConfig = rawMetricsConfig; - if (!NVPW_CALL( - NVPW_RawMetricsConfig_GenerateConfigImage(&generateConfigImageParams))) { - return false; - } - - // Get the Config Image size... nearly there - NVPW_RawMetricsConfig_GetConfigImage_Params getConfigImageParams = { - NVPW_RawMetricsConfig_GetConfigImage_Params_STRUCT_SIZE, nullptr}; - getConfigImageParams.pRawMetricsConfig = rawMetricsConfig; - getConfigImageParams.bytesAllocated = 0; - getConfigImageParams.pBuffer = nullptr; - if (!NVPW_CALL( - NVPW_RawMetricsConfig_GetConfigImage(&getConfigImageParams))) { - return false; - } - - configImage.resize(getConfigImageParams.bytesCopied); - - // Write the Config image binary - getConfigImageParams.bytesAllocated = configImage.size(); - getConfigImageParams.pBuffer = configImage.data(); - if (!NVPW_CALL( - NVPW_RawMetricsConfig_GetConfigImage(&getConfigImageParams))) { - return false; - } - - return true; -} - -bool getCounterDataPrefixImage( - const std::string& chipName, - const std::vector& metricNames, - std::vector& counterDataImagePrefix) { - - NVPW_CUDA_MetricsContext_Create_Params metricsContextCreateParams = { - NVPW_CUDA_MetricsContext_Create_Params_STRUCT_SIZE, nullptr}; - metricsContextCreateParams.pChipName = chipName.c_str(); - - if (!NVPW_CALL( - NVPW_CUDA_MetricsContext_Create(&metricsContextCreateParams))) { - return false; - } - - NVPW_MetricsContext_Destroy_Params metricsContextDestroyParams = { - NVPW_MetricsContext_Destroy_Params_STRUCT_SIZE, nullptr}; - metricsContextDestroyParams.pMetricsContext = - metricsContextCreateParams.pMetricsContext; - - - SCOPE_EXIT([&]() { - NVPW_MetricsContext_Destroy( - (NVPW_MetricsContext_Destroy_Params*)&metricsContextDestroyParams); - }); - - // Get all raw metrics required for given metricNames list - std::vector rawMetricRequests; - - // note: we need a variable at this functions scope to hold the string - // pointers for underlying C char arrays. - std::vector rawMetricDeps; - - if (!getRawMetricRequests( - metricsContextCreateParams.pMetricsContext, - metricNames, - rawMetricDeps, - rawMetricRequests)) { - return false; - } - - // Setup Counter Data builder - NVPW_CounterDataBuilder_Create_Params counterDataBuilderCreateParams = { - NVPW_CounterDataBuilder_Create_Params_STRUCT_SIZE, nullptr}; - counterDataBuilderCreateParams.pChipName = chipName.c_str(); - if (!NVPW_CALL( - NVPW_CounterDataBuilder_Create(&counterDataBuilderCreateParams))) { - return false; - } - - NVPW_CounterDataBuilder_Destroy_Params counterDataBuilderDestroyParams = { - NVPW_CounterDataBuilder_Destroy_Params_STRUCT_SIZE, nullptr}; - counterDataBuilderDestroyParams.pCounterDataBuilder = - counterDataBuilderCreateParams.pCounterDataBuilder; - SCOPE_EXIT([&]() { - NVPW_CounterDataBuilder_Destroy(( - NVPW_CounterDataBuilder_Destroy_Params*)&counterDataBuilderDestroyParams); - }); - - // Add metrics to counter data image prefix - NVPW_CounterDataBuilder_AddMetrics_Params addMetricsParams = { - NVPW_CounterDataBuilder_AddMetrics_Params_STRUCT_SIZE, nullptr}; - addMetricsParams.pCounterDataBuilder = - counterDataBuilderCreateParams.pCounterDataBuilder; - addMetricsParams.pRawMetricRequests = rawMetricRequests.data(); - addMetricsParams.numMetricRequests = rawMetricRequests.size(); - if (!NVPW_CALL( - NVPW_CounterDataBuilder_AddMetrics(&addMetricsParams))) { - return false; - } - - // Get image prefix size - NVPW_CounterDataBuilder_GetCounterDataPrefix_Params - getCounterDataPrefixParams = { - NVPW_CounterDataBuilder_GetCounterDataPrefix_Params_STRUCT_SIZE, nullptr}; - getCounterDataPrefixParams.pCounterDataBuilder = - counterDataBuilderCreateParams.pCounterDataBuilder; - getCounterDataPrefixParams.bytesAllocated = 0; - getCounterDataPrefixParams.pBuffer = nullptr; - if (!NVPW_CALL( - NVPW_CounterDataBuilder_GetCounterDataPrefix( - &getCounterDataPrefixParams))) { - return false; - } - - counterDataImagePrefix.resize(getCounterDataPrefixParams.bytesCopied); - - // Now write counter data image prefix - getCounterDataPrefixParams.bytesAllocated = counterDataImagePrefix.size(); - getCounterDataPrefixParams.pBuffer = counterDataImagePrefix.data(); - if (!NVPW_CALL( - NVPW_CounterDataBuilder_GetCounterDataPrefix( - &getCounterDataPrefixParams))) { - return false; - } - - return true; -} - -// ------------------------------------------------- -// Metric and Counter Evaluation Utilities -// ------------------------------------------------- - -std::string getRangeDescription( - const std::vector& counterDataImage, - int rangeIndex) { - std::vector descriptionPtrs; - - NVPW_Profiler_CounterData_GetRangeDescriptions_Params getRangeDescParams = { - NVPW_Profiler_CounterData_GetRangeDescriptions_Params_STRUCT_SIZE, nullptr}; - getRangeDescParams.pCounterDataImage = counterDataImage.data(); - getRangeDescParams.rangeIndex = rangeIndex; - - if (!NVPW_CALL( - NVPW_Profiler_CounterData_GetRangeDescriptions(&getRangeDescParams))) { - return ""; - } - - descriptionPtrs.resize(getRangeDescParams.numDescriptions); - getRangeDescParams.ppDescriptions = descriptionPtrs.data(); - - if (!NVPW_CALL( - NVPW_Profiler_CounterData_GetRangeDescriptions(&getRangeDescParams))) { - return ""; - } - - std::string rangeName; - - for (size_t i = 0; i < getRangeDescParams.numDescriptions; i++) { - if (i > 0) { - rangeName.append("/"); - } - rangeName.append(descriptionPtrs[i]); - } - return rangeName; -} - -CuptiProfilerResult evalMetricValues( - const std::string& chipName, - const std::vector& counterDataImage, - const std::vector& metricNames, - bool verbose) { - - if (!counterDataImage.size()) { - LOG(ERROR) << "Counter Data Image is empty!"; - return {}; - } - - NVPW_CUDA_MetricsContext_Create_Params metricsContextCreateParams = { - NVPW_CUDA_MetricsContext_Create_Params_STRUCT_SIZE, nullptr}; - metricsContextCreateParams.pChipName = chipName.c_str(); - if (!NVPW_CALL( - NVPW_CUDA_MetricsContext_Create(&metricsContextCreateParams))) { - return {}; - } - - NVPW_MetricsContext_Destroy_Params metricsContextDestroyParams = { - NVPW_MetricsContext_Destroy_Params_STRUCT_SIZE, nullptr}; - metricsContextDestroyParams.pMetricsContext = - metricsContextCreateParams.pMetricsContext; - SCOPE_EXIT([&]() { - NVPW_MetricsContext_Destroy( - (NVPW_MetricsContext_Destroy_Params*)&metricsContextDestroyParams); - }); - - NVPW_CounterData_GetNumRanges_Params getNumRangesParams = { - NVPW_CounterData_GetNumRanges_Params_STRUCT_SIZE, nullptr}; - getNumRangesParams.pCounterDataImage = counterDataImage.data(); - if (!NVPW_CALL( - NVPW_CounterData_GetNumRanges(&getNumRangesParams))) { - return {}; - } - - // TBD in the future support special chars in metric name - // for now these are default - const bool isolated = true; - - // API takes a 2D array of chars - std::vector metricNamePtrs; - - for (const auto& metric : metricNames) { - metricNamePtrs.push_back(metric.c_str()); - } - - CuptiProfilerResult result{ - .metricNames = metricNames}; - - for (size_t rangeIndex = 0; rangeIndex < getNumRangesParams.numRanges; - ++rangeIndex) { - - CuptiRangeMeasurement rangeData { - .rangeName = getRangeDescription(counterDataImage, rangeIndex)}; - rangeData.values.resize(metricNames.size()); - - // First set Counter data image with current range - NVPW_MetricsContext_SetCounterData_Params setCounterDataParams = { - NVPW_MetricsContext_SetCounterData_Params_STRUCT_SIZE, nullptr}; - - setCounterDataParams.pMetricsContext = - metricsContextCreateParams.pMetricsContext; - setCounterDataParams.pCounterDataImage = counterDataImage.data(); - setCounterDataParams.isolated = isolated; - setCounterDataParams.rangeIndex = rangeIndex; - - NVPW_CALL(NVPW_MetricsContext_SetCounterData(&setCounterDataParams)); - - - // Now we can evaluate GPU metrics - NVPW_MetricsContext_EvaluateToGpuValues_Params evalToGpuParams = { - NVPW_MetricsContext_EvaluateToGpuValues_Params_STRUCT_SIZE, nullptr}; - evalToGpuParams.pMetricsContext = - metricsContextCreateParams.pMetricsContext; - evalToGpuParams.numMetrics = metricNamePtrs.size(); - evalToGpuParams.ppMetricNames = metricNamePtrs.data(); - evalToGpuParams.pMetricValues = rangeData.values.data(); - - if (!NVPW_CALL(NVPW_MetricsContext_EvaluateToGpuValues(&evalToGpuParams))) { - LOG(WARNING) << "Failed to evaluate metris for range : " - << rangeData.rangeName; - continue; - } - - if (verbose) { - for (size_t i = 0; i < metricNames.size(); i++) { - LOG(INFO) << "rangeName: " << rangeData.rangeName - << "\tmetricName: " << metricNames[i] - << "\tgpuValue: " << rangeData.values[i]; - } - } - - result.rangeVals.emplace_back(std::move(rangeData)); - } - - return result; -} - -#else - -bool getProfilerConfigImage( - const std::string& /*chipName*/, - const std::vector& /*metricNames*/, - std::vector& /*configImage*/, - const uint8_t* /*counterAvailabilityImage*/) { - return false; -} - -bool getCounterDataPrefixImage( - const std::string& /*chipName*/, - const std::vector& /*metricNames*/, - std::vector& /*counterDataImagePrefix*/) { - return false; -} - -CuptiProfilerResult evalMetricValues( - const std::string& /*chipName*/, - const std::vector& /*counterDataImage*/, - const std::vector& /*metricNames*/, - bool /*verbose*/) { - return {}; -} - -#endif // cuda version > 10.00 and < 11.04 - -} // namespace nvperf -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.h b/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.h deleted file mode 100644 index d5dd1b1c1d20b066891f8be679e6d6371d4f4a9b..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.h +++ /dev/null @@ -1,71 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "Logger.h" - -namespace KINETO_NAMESPACE { - -struct CuptiRangeMeasurement { - std::string rangeName; - std::vector values; -}; - -struct CuptiProfilerResult { - std::vector metricNames; - // rangeName, list values - std::vector rangeVals; -}; - -/* Utilities for CUPTI and NVIDIA PerfWorks Metric API - */ - -#define NVPW_CALL(call) \ - [&]() -> bool { \ - NVPA_Status _status_ = call; \ - if (_status_ != NVPA_STATUS_SUCCESS) { \ - LOG(WARNING) << fmt::format( \ - "function {} failed with error ({})", \ - #call, \ - (int)_status_); \ - return false; \ - } \ - return true; \ - }() - -// fixme - add a results string -// nvpperfGetResultString(_status_, &_errstr_); - -namespace nvperf { - -// Setup CUPTI profiler configuration blob and counter data image prefix -bool getProfilerConfigImage( - const std::string& chipName, - const std::vector& metricNames, - std::vector& configImage, - const uint8_t* counterAvailabilityImage = nullptr); - -// Setup CUPTI profiler configuration blob and counter data image prefix -bool getCounterDataPrefixImage( - const std::string& chipName, - const std::vector& metricNames, - std::vector& counterDataImagePrefix); - -/* NV Perf Metric Evaluation helpers - * - utilities to read binary data and obtain metrics for ranges - */ -CuptiProfilerResult evalMetricValues( - const std::string& chipName, - const std::vector& counterDataImage, - const std::vector& metricNames, - bool verbose = false); - - -} // namespace nvperf -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.cpp deleted file mode 100644 index e5f18ed7b0b70963eb2deab126ff4f7119ed582b..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.cpp +++ /dev/null @@ -1,751 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include -#include -#ifdef HAS_CUPTI -#include -#include -#endif // HAS_CUPTI -#include -#include - -#ifdef HAS_CUPTI -#include "cupti_call.h" -#endif - -#include "time_since_epoch.h" -#include "Logger.h" -#include "Demangle.h" - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "CuptiCallbackApiMock.h" -#include "CuptiRangeProfilerApi.h" - -#if HAS_CUPTI_RANGE_PROFILER -#include -#include -#include "cupti_call.h" -#endif // HAS_CUPTI_RANGE_PROFILER - -namespace KINETO_NAMESPACE { - -#if HAS_CUPTI_RANGE_PROFILER -constexpr char kRootUserRangeName[] = "__profile__"; -constexpr int kCallbacksCountToFlush = 500; - -// Should we set Counter availability image ourselves? -// Disabled this right now as this call conflicts with DCGM -// It is not clear why it should conflict except it being a profiler API call -// TODO Revisit -constexpr bool kSetCounterAvail = false; - -// Shared state to track one Cupti Profiler API per Device -namespace { -// per device profiler maps -std::unordered_map profiler_map; -std::unordered_map enable_flag; -std::unordered_map disable_flag; - -std::mutex contextMutex_; -std::unordered_map ctx_to_dev; -std::set active_devices; -} - -// forward declarations -void __trackCudaCtx(CUcontext ctx, uint32_t device_id, CUpti_CallbackId cbid); -void __trackCudaKernelLaunch(CUcontext ctx, const char* kernelName); - -/// Helper functions - -// Available raw counters -std::vector getCounterAvailiability(CUcontext cuContext) { - std::vector counterAvailabilityImage; - CUpti_Profiler_GetCounterAvailability_Params getCounterAvailabilityParams = { - CUpti_Profiler_GetCounterAvailability_Params_STRUCT_SIZE, nullptr}; - getCounterAvailabilityParams.ctx = cuContext; - CUPTI_CALL( - cuptiProfilerGetCounterAvailability(&getCounterAvailabilityParams)); - - counterAvailabilityImage.clear(); - counterAvailabilityImage.resize( - getCounterAvailabilityParams.counterAvailabilityImageSize); - - getCounterAvailabilityParams.pCounterAvailabilityImage = - counterAvailabilityImage.data(); - CUPTI_CALL( - cuptiProfilerGetCounterAvailability(&getCounterAvailabilityParams)); - - return counterAvailabilityImage; -} - -std::string getChipName(int deviceId) { - // Get chip name for the cuda device - CUpti_Device_GetChipName_Params getChipNameParams = { - CUpti_Device_GetChipName_Params_STRUCT_SIZE, nullptr}; - - getChipNameParams.deviceIndex = deviceId; - CUPTI_CALL(cuptiDeviceGetChipName(&getChipNameParams)); - - return getChipNameParams.pChipName; -} - -inline uint32_t getDevID(CUcontext ctx) { - uint32_t device_id = UINT32_MAX; - CUPTI_CALL(cuptiGetDeviceId(ctx, &device_id)); - if (device_id == UINT32_MAX) { - LOG(ERROR) << "Could not determine dev id for = " << ctx; - } - return device_id; -} - -// We use CUPTI Callback functions in three ways : -// 1. Track cuda contexts and maintain a list of active GPUs to profile -// 2. Callbacks on kernel launches to track the name of automatic -// ranges that correspond to names of kernels -// 3. Lastly CUPTI profiler has to be enabled on the same thread executing -// the CUDA kernels. We use Callbacks to enable the profiler -// asynchronously from another thread. - -void disableKernelCallbacks(); - -void trackCudaCtx( - CUpti_CallbackDomain /*domain*/, - CUpti_CallbackId cbid, - const CUpti_CallbackData* cbInfo) { - auto *d = reinterpret_cast(cbInfo); - auto ctx = d->context; - uint32_t device_id = getDevID(ctx); - - if (device_id == UINT32_MAX) { - return; - } - - __trackCudaCtx(ctx, device_id, cbid); -} - -void __trackCudaCtx(CUcontext ctx, uint32_t device_id, CUpti_CallbackId cbid) { - std::lock_guard g(contextMutex_); - if (cbid == CUPTI_CBID_RESOURCE_CONTEXT_CREATED) { - VLOG(0) << "CUPTI Profiler observed CUDA Context created = " - << ctx << " device id = " << device_id; - active_devices.insert(device_id); - if constexpr (kSetCounterAvail) { - if (active_devices.size() == 1) { - CuptiRBProfilerSession::setCounterAvailabilityImage( - getCounterAvailiability(ctx)); - } - } - ctx_to_dev[ctx] = device_id; - - } else if (cbid == CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING) { - VLOG(0) << "CUPTI Profiler observed CUDA Context destroyed = " - << ctx << " device id = " << device_id; - auto it = active_devices.find(device_id); - if (it != active_devices.end()) { - active_devices.erase(it); - ctx_to_dev.erase(ctx); - } - } -} - -void trackCudaKernelLaunch( - CUpti_CallbackDomain /*domain*/, - CUpti_CallbackId /*cbid*/, - const CUpti_CallbackData* cbInfo) { - VLOG(1) << " Trace : Callback name = " - << (cbInfo->symbolName ? cbInfo->symbolName: "") - << " context ptr = " << cbInfo->context; - auto ctx = cbInfo->context; - // should be in CUPTI_API_ENTER call site - if (cbInfo->callbackSite != CUPTI_API_ENTER) { - return; - } - __trackCudaKernelLaunch(ctx, cbInfo->symbolName); -} - -void __trackCudaKernelLaunch( - CUcontext ctx, - const char* kernelName) { - VLOG(0) << " Tracking kernel name = " << (kernelName ? kernelName : "") - << " context ptr = " << ctx; - - uint32_t device_id = 0; - auto it = ctx_to_dev.find(ctx); - if (it == ctx_to_dev.end()) { - // Warning here could be too noisy - VLOG(0) << " Could not find corresponding device to ctx = " << ctx; - return; - } else { - device_id = it->second; - } - - auto pit = profiler_map.find(device_id); - if (pit == profiler_map.end() || pit->second == nullptr) { - return; - } - auto profiler = pit->second; - - if (enable_flag[device_id]) { - LOG(INFO) << "Callback handler is enabling cupti profiler"; - profiler->startAndEnable(); - enable_flag[device_id] = false; - - } else if (disable_flag[device_id]) { - LOG(INFO) << "Callback handler is disabling cupti profiler"; - profiler->disableAndStop(); - return; - } - - if (profiler->curRange_ == CUPTI_AutoRange) { - profiler->logKernelName(kernelName ? kernelName : "__missing__"); - } - - /* TODO add per kernel time logging - if (measure_per_kernel) { - profiler->kernelStartTs_.push_back( - std::chrono::high_resolution_clock::now()); - } - */ - - // periodically flush profiler data from GPU - if (profiler->numCallbacks_ % kCallbacksCountToFlush == 0) { - profiler->flushCounterData(); - } - profiler->numCallbacks_++; -} - -void enableKernelCallbacks() { - auto& cbapi = CuptiCallbackApi::singleton(); - bool status = cbapi.enableCallback( - CUPTI_CB_DOMAIN_RUNTIME_API, - CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000); - if (!status) { - LOG(WARNING) << "CUPTI Range Profiler unable to " - << "enable cuda kernel launch callback"; - return; - } - LOG(INFO) << "CUPTI Profiler kernel callbacks enabled"; -} - -void disableKernelCallbacks() { - auto& cbapi = CuptiCallbackApi::singleton(); - bool status = cbapi.disableCallback( - CUPTI_CB_DOMAIN_RUNTIME_API, - CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000); - if (!status) { - LOG(WARNING) << "CUPTI Range Profiler unable to " - << "disable cuda kernel launch callback"; - return; - } - LOG(INFO) << "CUPTI Profiler kernel callbacks disabled"; -} - -// static -std::set CuptiRBProfilerSession::getActiveDevices() { - std::lock_guard g(contextMutex_); - return active_devices; -} - -// static -void CuptiRBProfilerSession::initCupti() { - CUpti_Profiler_Initialize_Params profilerInitializeParams = { - CUpti_Profiler_Initialize_Params_STRUCT_SIZE, nullptr}; - CUPTI_CALL(cuptiProfilerInitialize(&profilerInitializeParams)); -} - -// static -void CuptiRBProfilerSession::deInitCupti() { - CUpti_Profiler_DeInitialize_Params profilerDeInitializeParams = { - CUpti_Profiler_DeInitialize_Params_STRUCT_SIZE, nullptr}; - CUPTI_CALL(cuptiProfilerDeInitialize(&profilerDeInitializeParams)); -} - -// static -void CuptiRBProfilerSession::staticInit() { - CuptiRBProfilerSession::initCupti(); - - // Register CUPTI callbacks - auto& cbapi = CuptiCallbackApi::singleton(); - CUpti_CallbackDomain domain = CUPTI_CB_DOMAIN_RESOURCE; - bool status = cbapi.registerCallback( - domain, CuptiCallbackApi::RESOURCE_CONTEXT_CREATED, trackCudaCtx); - status = status && cbapi.registerCallback( - domain, CuptiCallbackApi::RESOURCE_CONTEXT_DESTROYED, trackCudaCtx); - status = status && cbapi.enableCallback( - domain, CUPTI_CBID_RESOURCE_CONTEXT_CREATED); - status = status && cbapi.enableCallback( - domain, CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING); - - if (!status) { - LOG(WARNING) << "CUPTI Range Profiler unable to attach cuda context " - << "create and destroy callbacks"; - CUPTI_CALL(cbapi.getCuptiStatus()); - return; - } - - domain = CUPTI_CB_DOMAIN_RUNTIME_API; - status = cbapi.registerCallback( - domain, CuptiCallbackApi::CUDA_LAUNCH_KERNEL, trackCudaKernelLaunch); - - if (!status) { - LOG(WARNING) << "CUPTI Range Profiler unable to attach cuda kernel " - << "launch callback"; - return; - } -} - -// static -std::vector& CuptiRBProfilerSession::counterAvailabilityImage() { - static std::vector counterAvailabilityImage_; - return counterAvailabilityImage_; -} - - -// Setup the profiler sessions -CuptiRBProfilerSession::CuptiRBProfilerSession( - const std::vector& metricNames, - int deviceId, - int maxRanges, - int numNestingLevels, - CUcontext cuContext) - : metricNames_(metricNames), - chipName_(getChipName(deviceId)), - deviceId_(deviceId), - maxRanges_(maxRanges), - numNestingLevels_(numNestingLevels), - cuContext_(cuContext) { - CuptiRBProfilerSession::initCupti(); - - LOG(INFO) << "Initializing CUPTI profiler session : device = " << deviceId - << " chip = " << chipName_; - /* Generate configuration for metrics, this can also be done offline*/ - NVPW_InitializeHost_Params initializeHostParams = { - NVPW_InitializeHost_Params_STRUCT_SIZE, nullptr}; - NVPW_CALL(NVPW_InitializeHost(&initializeHostParams)); - - if (metricNames.size()) { - if (!nvperf::getProfilerConfigImage( - chipName_, - metricNames, - configImage, - CuptiRBProfilerSession::counterAvailabilityImage().data())) { - LOG(ERROR) << "Failed to create configImage or counterDataImagePrefix"; - return; - } - if (!nvperf::getCounterDataPrefixImage( - chipName_, - metricNames, - counterDataImagePrefix)) { - LOG(ERROR) << "Failed to create counterDataImagePrefix"; - return; - } - } else { - LOG(ERROR) << "No metrics provided to profile"; - return; - } - - if (!createCounterDataImage()) { - LOG(ERROR) << "Failed to create counterDataImage"; - return; - } - - LOG(INFO) << "Size of structs\n" - << " config image size = " << configImage.size() << " B" - << " counter data image prefix = " - << counterDataImagePrefix.size() << " B" - << " counter data image size = " << counterDataImage.size() / 1024 - << " KB" - << " counter sb image size = " - << counterDataScratchBuffer.size() << " B"; - - beginPassParams_ = {CUpti_Profiler_BeginPass_Params_STRUCT_SIZE, nullptr}; - endPassParams_ = {CUpti_Profiler_EndPass_Params_STRUCT_SIZE, nullptr}; - - initSuccess_ = true; - profiler_map[deviceId] = this; -} - -// used in unittests only -CuptiRBProfilerSession::CuptiRBProfilerSession(int deviceId, CUcontext ctx) - : deviceId_(deviceId), cuContext_(ctx) { - initSuccess_ = true; - profiler_map[deviceId] = this; -} - -void CuptiRBProfilerSession::startInternal( - CUpti_ProfilerRange profilerRange, - CUpti_ProfilerReplayMode profilerReplayMode) { - LOG(INFO) << "Starting profiler session: profiler range = " - << ((profilerRange == CUPTI_AutoRange) ? "autorange" : "userrange") - << " replay mode = " - << ((profilerReplayMode == CUPTI_KernelReplay) ? "kernel" : "user"); - if (!initSuccess_) { - LOG(WARNING) << __func__ << "() bailing out since initialization failed"; - return; - } - - if (cuContext_ == nullptr) { - for (const auto& it : ctx_to_dev) { - if (it.second == deviceId_) { - cuContext_ = it.first; - break; - } - } - LOG(INFO) << " Cupti Profiler using CUDA context = " << cuContext_; - } - - profilerStartTs_ = std::chrono::high_resolution_clock::now(); - curRange_ = profilerRange; - curReplay_ = profilerReplayMode; - - CUpti_Profiler_BeginSession_Params beginSessionParams = { - CUpti_Profiler_BeginSession_Params_STRUCT_SIZE, nullptr}; - - beginSessionParams.ctx = cuContext_; - beginSessionParams.counterDataImageSize = counterDataImage.size(); - beginSessionParams.pCounterDataImage = counterDataImage.data(); - beginSessionParams.counterDataScratchBufferSize = - counterDataScratchBuffer.size(); - beginSessionParams.pCounterDataScratchBuffer = counterDataScratchBuffer.data(); - beginSessionParams.range = profilerRange; - beginSessionParams.replayMode = profilerReplayMode; - beginSessionParams.maxRangesPerPass = maxRanges_; - beginSessionParams.maxLaunchesPerPass = maxRanges_; - - auto status = CUPTI_CALL(cuptiProfilerBeginSession(&beginSessionParams)); - if (status != CUPTI_SUCCESS) { - LOG(WARNING) << "Failed to start CUPTI profiler"; - initSuccess_ = false; - return; - } - - // Set counter configuration - CUpti_Profiler_SetConfig_Params setConfigParams = { - CUpti_Profiler_SetConfig_Params_STRUCT_SIZE, nullptr}; - - setConfigParams.ctx = cuContext_; - setConfigParams.pConfig = configImage.data(); - setConfigParams.configSize = configImage.size(); - setConfigParams.passIndex = 0; - setConfigParams.minNestingLevel = 1; - setConfigParams.numNestingLevels = numNestingLevels_; - status = CUPTI_CALL(cuptiProfilerSetConfig(&setConfigParams)); - - if (status != CUPTI_SUCCESS) { - LOG(WARNING) << "Failed to configure CUPTI profiler"; - initSuccess_ = false; - return; - } - profilerInitDoneTs_ = std::chrono::high_resolution_clock::now(); - - if (curRange_ == CUPTI_AutoRange) { - enableKernelCallbacks(); - } - profilingActive_ = true; -} - -void CuptiRBProfilerSession::stop() { - if (!initSuccess_) { - LOG(WARNING) << __func__ << "() bailing out since initialization failed"; - return; - } - LOG(INFO) << "Stop profiler session on device = " << deviceId_; - - CUpti_Profiler_UnsetConfig_Params unsetConfigParams = { - CUpti_Profiler_UnsetConfig_Params_STRUCT_SIZE, nullptr}; - CUPTI_CALL(cuptiProfilerUnsetConfig(&unsetConfigParams)); - - CUpti_Profiler_EndSession_Params endSessionParams = { - CUpti_Profiler_EndSession_Params_STRUCT_SIZE, nullptr}; - CUPTI_CALL(cuptiProfilerEndSession(&endSessionParams)); - - disableKernelCallbacks(); - - profilerStopTs_ = std::chrono::high_resolution_clock::now(); - profilingActive_ = false; -} - -void CuptiRBProfilerSession::beginPass() { - if (!initSuccess_) { - LOG(WARNING) << __func__ << "() bailing out since initialization failed"; - return; - } - CUPTI_CALL(cuptiProfilerBeginPass(&beginPassParams_)); -} - -bool CuptiRBProfilerSession::endPass() { - if (!initSuccess_) { - LOG(WARNING) << __func__ << "() bailing out since initialization failed"; - return true; - } - CUPTI_CALL(cuptiProfilerEndPass(&endPassParams_)); - return endPassParams_.allPassesSubmitted; -} - -void CuptiRBProfilerSession::flushCounterData() { - LOG(INFO) << "Flushing counter data on device = " << deviceId_; - CUpti_Profiler_FlushCounterData_Params flushCounterDataParams = { - CUpti_Profiler_FlushCounterData_Params_STRUCT_SIZE, nullptr}; - CUPTI_CALL(cuptiProfilerFlushCounterData(&flushCounterDataParams)); -} - -/// Enable and disable the profiler -void CuptiRBProfilerSession::enable() { - if (!initSuccess_) { - LOG(WARNING) << __func__ << "() bailing out since initialization failed"; - return; - } - CUpti_Profiler_EnableProfiling_Params enableProfilingParams = { - CUpti_Profiler_EnableProfiling_Params_STRUCT_SIZE, nullptr}; - CUPTI_CALL(cuptiProfilerEnableProfiling(&enableProfilingParams)); -} - -void CuptiRBProfilerSession::disable() { - if (!initSuccess_) { - LOG(WARNING) << __func__ << "() bailing out since initialization failed"; - return; - } - CUpti_Profiler_DisableProfiling_Params disableProfilingParams = { - CUpti_Profiler_DisableProfiling_Params_STRUCT_SIZE, nullptr}; - CUPTI_CALL(cuptiProfilerDisableProfiling(&disableProfilingParams)); -} - -/// User range based profiling -void CuptiRBProfilerSession::pushRange(const std::string& rangeName) { - LOG(INFO) << " CUPTI pushrange ( " << rangeName << " )"; - CUpti_Profiler_PushRange_Params pushRangeParams = { - CUpti_Profiler_PushRange_Params_STRUCT_SIZE, nullptr}; - pushRangeParams.pRangeName = rangeName.c_str(); - CUPTI_CALL(cuptiProfilerPushRange(&pushRangeParams)); -} - -void CuptiRBProfilerSession::popRange() { - LOG(INFO) << " CUPTI pop range"; - CUpti_Profiler_PopRange_Params popRangeParams = { - CUpti_Profiler_PopRange_Params_STRUCT_SIZE, nullptr}; - CUPTI_CALL(cuptiProfilerPopRange(&popRangeParams)); -} - -void CuptiRBProfilerSession::startAndEnable() { - startInternal(curRange_, curReplay_); - if (curReplay_ == CUPTI_UserReplay) { - beginPass(); - } - enable(); - if (curRange_ == CUPTI_UserRange) { - pushRange(kRootUserRangeName); - } - enable_flag[deviceId_] = false; -} - -void CuptiRBProfilerSession::disableAndStop() { - if (curRange_ == CUPTI_UserRange) { - popRange(); - } - disable(); - if (curReplay_ == CUPTI_UserReplay) { - endPass(); - flushCounterData(); - } - stop(); - disable_flag[deviceId_] = false; -} - -void CuptiRBProfilerSession::asyncStartAndEnable( - CUpti_ProfilerRange profilerRange, - CUpti_ProfilerReplayMode profilerReplayMode) { - LOG(INFO) << "Starting CUPTI profiler asynchronously on device = " - << deviceId_ << " profiler range = " - << ((profilerRange == CUPTI_AutoRange) ? "autorange" : "userrange") - << " replay mode = " - << ((profilerReplayMode == CUPTI_KernelReplay) ? "kernel" : "user"); - curReplay_ = profilerReplayMode; - curRange_ = profilerRange; - enable_flag[deviceId_] = true; - enableKernelCallbacks(); -} - -void CuptiRBProfilerSession::asyncDisableAndStop() { - LOG(INFO) << "Stopping CUPTI profiler asynchronously on device = " - << deviceId_ << " cu context = " << cuContext_; - disable_flag[deviceId_] = true; -} - - -CuptiProfilerResult CuptiRBProfilerSession::evaluateMetrics( - bool verbose) { - if (!initSuccess_) { - LOG(WARNING) << "Profiling failed, no results to return"; - return {}; - } - if (profilingActive_) { - disableAndStop(); - } - - LOG(INFO) << "Total kernels logged = " << kernelNames_.size(); - if (verbose) { - for (const auto& kernel : kernelNames_) { - std::cout << demangle(kernel) << std::endl; - } - LOG(INFO) << "Profiler Range data : "; - } - - auto results = nvperf::evalMetricValues( - chipName_, counterDataImage, metricNames_, verbose /*verbose*/); - - // profiler end-end duration - auto duration_ms = std::chrono::duration_cast( - profilerStopTs_ - profilerStartTs_); - - auto init_dur_ms = std::chrono::duration_cast( - profilerInitDoneTs_ - profilerStartTs_); - LOG(INFO) << "Total profiler time = " << duration_ms.count() << " ms"; - LOG(INFO) << "Total profiler init time = " << init_dur_ms.count() << " ms"; - - return results; -} - -std::unique_ptr CuptiRBProfilerSession::getProfilerTraceSpan() { - return std::make_unique( - timeSinceEpoch(profilerStartTs_), - timeSinceEpoch(profilerStopTs_), - "__cupti_profiler__" - ); -} - -void CuptiRBProfilerSession::saveCounterData( - const std::string& /*CounterDataFileName*/, - const std::string& /*CounterDataSBFileName*/) { - /* TBD write binary files for counter data and counter scratch buffer */ -} - -/// Setup counter data -bool CuptiRBProfilerSession::createCounterDataImage() { - CUpti_Profiler_CounterDataImageOptions counterDataImageOptions; - counterDataImageOptions.pCounterDataPrefix = counterDataImagePrefix.data(); - counterDataImageOptions.counterDataPrefixSize = counterDataImagePrefix.size(); - counterDataImageOptions.maxNumRanges = maxRanges_; - counterDataImageOptions.maxNumRangeTreeNodes = maxRanges_; - counterDataImageOptions.maxRangeNameLength = 64; - - // Calculate size of counter data image - CUpti_Profiler_CounterDataImage_CalculateSize_Params calculateSizeParams = { - CUpti_Profiler_CounterDataImage_CalculateSize_Params_STRUCT_SIZE, nullptr}; - calculateSizeParams.pOptions = &counterDataImageOptions; - calculateSizeParams.sizeofCounterDataImageOptions = - CUpti_Profiler_CounterDataImageOptions_STRUCT_SIZE; - - CUPTI_CALL( - cuptiProfilerCounterDataImageCalculateSize(&calculateSizeParams)); - counterDataImage.resize(calculateSizeParams.counterDataImageSize); - - // Initialize counter data image - CUpti_Profiler_CounterDataImage_Initialize_Params initializeParams = { - CUpti_Profiler_CounterDataImage_Initialize_Params_STRUCT_SIZE, nullptr}; - initializeParams.sizeofCounterDataImageOptions = - CUpti_Profiler_CounterDataImageOptions_STRUCT_SIZE; - initializeParams.pOptions = &counterDataImageOptions; - initializeParams.counterDataImageSize = - calculateSizeParams.counterDataImageSize; - initializeParams.pCounterDataImage = counterDataImage.data(); - CUPTI_CALL(cuptiProfilerCounterDataImageInitialize(&initializeParams)); - - // Calculate counter Scratch Buffer size - CUpti_Profiler_CounterDataImage_CalculateScratchBufferSize_Params - scratchBufferSizeParams = { - CUpti_Profiler_CounterDataImage_CalculateScratchBufferSize_Params_STRUCT_SIZE, nullptr}; - - scratchBufferSizeParams.counterDataImageSize = - calculateSizeParams.counterDataImageSize; - scratchBufferSizeParams.pCounterDataImage = - initializeParams.pCounterDataImage; - CUPTI_CALL(cuptiProfilerCounterDataImageCalculateScratchBufferSize( - &scratchBufferSizeParams)); - - counterDataScratchBuffer.resize( - scratchBufferSizeParams.counterDataScratchBufferSize); - - // Initialize scratch buffer - CUpti_Profiler_CounterDataImage_InitializeScratchBuffer_Params - initScratchBufferParams = { - CUpti_Profiler_CounterDataImage_InitializeScratchBuffer_Params_STRUCT_SIZE, nullptr}; - - initScratchBufferParams.counterDataImageSize = - calculateSizeParams.counterDataImageSize; - - initScratchBufferParams.pCounterDataImage = - initializeParams.pCounterDataImage; - initScratchBufferParams.counterDataScratchBufferSize = - scratchBufferSizeParams.counterDataScratchBufferSize; - initScratchBufferParams.pCounterDataScratchBuffer = - counterDataScratchBuffer.data(); - - CUPTI_CALL(cuptiProfilerCounterDataImageInitializeScratchBuffer( - &initScratchBufferParams)); - - return true; -} - -#elif defined(HAS_CUPTI) - -// Create empty stubs for the API when CUPTI is not present. -CuptiRBProfilerSession::CuptiRBProfilerSession( - const std::vector& metricNames, - int deviceId, - int maxRanges, - int numNestingLevels, - CUcontext cuContext) - : metricNames_(metricNames), - deviceId_(deviceId), - maxRanges_(maxRanges), - numNestingLevels_(numNestingLevels), - cuContext_(cuContext) {} -void CuptiRBProfilerSession::stop() {} -void CuptiRBProfilerSession::enable() {} -void CuptiRBProfilerSession::disable() {} -void CuptiRBProfilerSession::beginPass() {} -bool CuptiRBProfilerSession::endPass() { return true; } -void CuptiRBProfilerSession::flushCounterData() {} -void CuptiRBProfilerSession::pushRange(const std::string& /*rangeName*/) {} -void CuptiRBProfilerSession::popRange() {} -void CuptiRBProfilerSession::asyncStartAndEnable( - CUpti_ProfilerRange /*profilerRange*/, - CUpti_ProfilerReplayMode /*profilerReplayMode*/) {} -void CuptiRBProfilerSession::asyncDisableAndStop() {} -CuptiProfilerResult CuptiRBProfilerSession::evaluateMetrics(bool verbose) { - static CuptiProfilerResult res; - return res; -}; -void CuptiRBProfilerSession::saveCounterData( - const std::string& /*CounterDataFileName*/, - const std::string& /*CounterDataSBFileName*/) {} -void CuptiRBProfilerSession::initCupti() {} -void CuptiRBProfilerSession::deInitCupti() {} -void CuptiRBProfilerSession::staticInit() {} -bool CuptiRBProfilerSession::createCounterDataImage() { return true; } -void CuptiRBProfilerSession::startInternal( - CUpti_ProfilerRange /*profilerRange*/, - CUpti_ProfilerReplayMode /*profilerReplayMode*/) {} -std::vector& CuptiRBProfilerSession::counterAvailabilityImage() { - static std::vector _vec; - return _vec; -} -#endif // HAS_CUPTI_RANGE_PROFILER - -namespace testing { - -void trackCudaCtx(CUcontext ctx, uint32_t device_id, CUpti_CallbackId cbid) { -#if HAS_CUPTI_RANGE_PROFILER - __trackCudaCtx(ctx, device_id, cbid); -#endif // HAS_CUPTI_RANGE_PROFILER -} - -void trackCudaKernelLaunch(CUcontext ctx, const char* kernelName) { -#if HAS_CUPTI_RANGE_PROFILER - __trackCudaKernelLaunch(ctx, kernelName); -#endif // HAS_CUPTI_RANGE_PROFILER -} - -} // namespace testing -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.h b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.h deleted file mode 100644 index 98a0b3ea5f4850dfa060e4e86d5ebf210692db1a..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.h +++ /dev/null @@ -1,220 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#ifdef HAS_CUPTI -#include -#include -// Using CUDA 11 and above due to usage of API: cuptiProfilerGetCounterAvailability. -#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000 && CUDART_VERSION < 11040 && CUDA_VERSION >= 11000 -#define HAS_CUPTI_RANGE_PROFILER 1 -#endif // CUDART_VERSION > 10.00 and < 11.04 && CUDA_VERSION >= 11.00 -#endif // HAS_CUPTI - -#if HAS_CUPTI_RANGE_PROFILER -#include -#include -#include -#else -using CUpti_ProfilerRange = enum -{ - CUPTI_AutoRange, - CUPTI_UserRange, -}; - -using CUpti_ProfilerReplayMode = enum -{ - CUPTI_KernelReplay, - CUPTI_UserReplay, -}; -#endif // HAS_CUPTI_RANGE_PROFILER - -#include -#include -#include -#include -#include - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "TraceSpan.h" -#include "CuptiCallbackApi.h" -#include "CuptiNvPerfMetric.h" - -/* Cupti Range based profiler session - * See : https://docs.nvidia.com/cupti/Cupti/r_main.html#r_profiler - */ - -namespace KINETO_NAMESPACE { - -class CuptiRBProfilerSession { - public: - // Initialize and configure CUPTI Profiler counters. - // - Metric names must be provided as string vector. - // - Supported values by CUPTI can be found at - - // https://docs.nvidia.com/cupti/Cupti/r_main.html#r_host_metrics_api - explicit CuptiRBProfilerSession( - const std::vector& metricNames, - int deviceId, - int maxRanges, - int numNestingLevels = 1, - CUcontext cuContext = 0); - - virtual ~CuptiRBProfilerSession() = default; - - // Start profiling session - // This function has to be called from the CPU thread running - // the CUDA context. If this is not the case asyncStartAndEnable() - // can be used - void start( - CUpti_ProfilerRange profilerRange = CUPTI_AutoRange, - CUpti_ProfilerReplayMode profilerReplayMode = CUPTI_KernelReplay) { - startInternal(profilerRange, profilerReplayMode); - } - - // Stop profiling session - virtual void stop(); - - virtual void enable(); - virtual void disable(); - - // Profiler passes - // GPU hardware has limited performance monitoring resources - // the CUPTI profiler may need to run multiple passes to collect - // data for a given range - // If we use kernel replay model the kernels are automatically replayed - // else, you can use the beginPass() and endPass() functions below - // for user to manage the replays - - // starts a profiler pass with given kernels in between - virtual void beginPass(); - - // end a profiler pass with given kernels in between - // returns true if no more passes are required - virtual bool endPass(); - - // flushes the counter data - required if you use user replay - virtual void flushCounterData(); - - // Each pass can contain multiple of ranges - // metrics configured in a pass are collected per each range-stack. - virtual void pushRange(const std::string& rangeName); - virtual void popRange(); - - // utilities for common operations - void startAndEnable(); - void disableAndStop(); - - // Async APIs : these will can be called from another thread - // outside the CUDA context being profiled - void asyncStartAndEnable( - CUpti_ProfilerRange profilerRange = CUPTI_AutoRange, - CUpti_ProfilerReplayMode profilerReplayMode = CUPTI_KernelReplay); - void asyncDisableAndStop(); - - void printMetrics() { - evaluateMetrics(true); - } - - std::unique_ptr getProfilerTraceSpan(); - - virtual CuptiProfilerResult evaluateMetrics(bool verbose = false); - - void saveCounterData( - const std::string& CounterDataFileName, - const std::string& CounterDataSBFileName); - - // This is not thread safe so please only call after - // profiling has stopped - const std::vector& getKernelNames() const { - return kernelNames_; - } - - int deviceId() const { - return deviceId_; - } - - bool profilingActive() const { - return profilingActive_; - } - - static std::set getActiveDevices(); - - static void initCupti(); - - static void deInitCupti(); - - static void staticInit(); - - static void setCounterAvailabilityImage(std::vector img) { - counterAvailabilityImage() = img; - } - protected: - CuptiRBProfilerSession(int deviceId, CUcontext ctx); - - virtual void startInternal( - CUpti_ProfilerRange profilerRange, - CUpti_ProfilerReplayMode profilerReplayMode); - - CUpti_ProfilerRange curRange_ = CUPTI_AutoRange; - CUpti_ProfilerReplayMode curReplay_ = CUPTI_KernelReplay; - - private: - - bool createCounterDataImage(); - - - // log kernel name that used with callbacks - void logKernelName(const char* kernel) { - std::lock_guard lg(kernelNamesMutex_); - kernelNames_.emplace_back(kernel); - } - - std::vector metricNames_; - std::string chipName_; - - uint32_t deviceId_ = 0; - int maxRanges_; - int numNestingLevels_; - CUcontext cuContext_; - - - // data buffers for configuration and counter data collection - std::vector counterDataImagePrefix; - std::vector configImage; - std::vector counterDataImage; - std::vector counterDataScratchBuffer; - - std::chrono::time_point profilerStartTs_; - std::chrono::time_point - profilerInitDoneTs_; - std::chrono::time_point profilerStopTs_; - - std::mutex kernelNamesMutex_; - // raw kernel names (not demangled) - std::vector kernelNames_; - - uint32_t numCallbacks_ = 0; - - static std::vector& counterAvailabilityImage(); - -#if HAS_CUPTI_RANGE_PROFILER - CUpti_Profiler_BeginPass_Params beginPassParams_; - CUpti_Profiler_EndPass_Params endPassParams_; -#endif - - bool initSuccess_ = false; - bool profilingActive_ = false; - - friend void __trackCudaKernelLaunch(CUcontext ctx, const char* kernelName); -}; - -// called directly only in unit tests -namespace testing { - -void trackCudaCtx(CUcontext ctx, uint32_t device_id, CUpti_CallbackId cbid); -void trackCudaKernelLaunch(CUcontext ctx, const char* kernelName); - -} // namespace testing - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.cpp b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.cpp deleted file mode 100644 index 04b1ad0cb3f807cf87d32bc03de0ca9b552b0063..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.cpp +++ /dev/null @@ -1,68 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include -#include - -#include -#include - -#include -#include - -using namespace std::chrono; - -namespace KINETO_NAMESPACE { - -// number of ranges affect the size of counter data binary used by -// the CUPTI Profiler. these defaults can be tuned -constexpr int KMaxAutoRanges = 1500; // supports 1500 kernels -constexpr int KMaxUserRanges = 10; // enable upto 10 sub regions marked by user - -constexpr char kCuptiProfilerMetricsKey[] = "CUPTI_PROFILER_METRICS"; -constexpr char kCuptiProfilerPerKernelKey[] = "CUPTI_PROFILER_ENABLE_PER_KERNEL"; -constexpr char kCuptiProfilerMaxRangesKey[] = "CUPTI_PROFILER_MAX_RANGES"; - -CuptiRangeProfilerConfig::CuptiRangeProfilerConfig(Config& cfg) - : parent_(&cfg), - cuptiProfilerPerKernel_(false), - cuptiProfilerMaxRanges_(0) {} - -bool CuptiRangeProfilerConfig::handleOption(const std::string& name, std::string& val) { - VLOG(0) << " handling : " << name << " = " << val; - // Cupti Range based Profiler configuration - if (!name.compare(kCuptiProfilerMetricsKey)) { - activitiesCuptiMetrics_ = splitAndTrim(val, ','); - } else if (!name.compare(kCuptiProfilerPerKernelKey)) { - cuptiProfilerPerKernel_ = toBool(val); - } else if (!name.compare(kCuptiProfilerMaxRangesKey)) { - cuptiProfilerMaxRanges_ = toInt64(val); - } else { - return false; - } - return true; -} - -void CuptiRangeProfilerConfig::setDefaults() { - if (activitiesCuptiMetrics_.size() > 0 && cuptiProfilerMaxRanges_ == 0) { - cuptiProfilerMaxRanges_ = - cuptiProfilerPerKernel_ ? KMaxAutoRanges : KMaxUserRanges; - } -} - -void CuptiRangeProfilerConfig::printActivityProfilerConfig(std::ostream& s) const { - if (activitiesCuptiMetrics_.size() > 0) { - s << "Cupti Profiler metrics : " - << fmt::format("{}", fmt::join(activitiesCuptiMetrics_, ", ")) << std::endl; - s << "Cupti Profiler measure per kernel : " - << cuptiProfilerPerKernel_ << std::endl; - s << "Cupti Profiler max ranges : " << cuptiProfilerMaxRanges_ << std::endl; - } -} - -void CuptiRangeProfilerConfig::registerFactory() { - Config::addConfigFactory( - kCuptiProfilerConfigName, - [](Config& cfg) { return new CuptiRangeProfilerConfig(cfg); }); -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.h b/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.h deleted file mode 100644 index 549b8a4e8b40c66b59bae974eb87c7f64967344e..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.h +++ /dev/null @@ -1,86 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include "Config.h" - -#include -#include -#include -#include - -namespace KINETO_NAMESPACE { - -constexpr char kCuptiProfilerConfigName[] = "cupti_rb_profiler"; - -class CuptiRangeProfilerConfig : public AbstractConfig { - public: - bool handleOption(const std::string& name, std::string& val) override; - - void validate( - const std::chrono::time_point& - fallbackProfileStartTime) override {} - - static CuptiRangeProfilerConfig& get(const Config& cfg) { - return dynamic_cast(cfg.feature( - kCuptiProfilerConfigName)); - } - - Config& parent() const { - return *parent_; - } - - std::vector activitiesCuptiMetrics() const { - return activitiesCuptiMetrics_; - } - - bool cuptiProfilerPerKernel() const { - return cuptiProfilerPerKernel_; - } - - int64_t cuptiProfilerMaxRanges() const { - return cuptiProfilerMaxRanges_; - } - - void setSignalDefaults() override { - setDefaults(); - } - - void setClientDefaults() override { - setDefaults(); - } - - void printActivityProfilerConfig(std::ostream& s) const override; - - static void registerFactory(); - protected: - AbstractConfig* cloneDerived(AbstractConfig& parent) const override { - CuptiRangeProfilerConfig* clone = new CuptiRangeProfilerConfig(*this); - clone->parent_ = dynamic_cast(&parent); - return clone; - } - - private: - CuptiRangeProfilerConfig() = delete; - explicit CuptiRangeProfilerConfig(Config& parent); - explicit CuptiRangeProfilerConfig( - const CuptiRangeProfilerConfig& other) = default; - - // some defaults will depend on other configuration - void setDefaults(); - - // Associated Config object - Config* parent_; - - // Counter metrics exposed via CUPTI Profiler API - std::vector activitiesCuptiMetrics_; - - // Collect profiler metrics per kernel - autorange made - bool cuptiProfilerPerKernel_{false}; - - // max number of ranges to configure the profiler for. - // this has to be set before hand to reserve space for the output - int64_t cuptiProfilerMaxRanges_ = 0; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/DaemonConfigLoader.h b/plugins/tensorboard-plugins/libkineto/src/DaemonConfigLoader.h deleted file mode 100644 index 9b0ed92863648824a57ce8193ddc16d7cf23622e..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/DaemonConfigLoader.h +++ /dev/null @@ -1,27 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include - -namespace KINETO_NAMESPACE { - -class DaemonConfigLoader { - public: - virtual ~DaemonConfigLoader() {} - - // Return the base config from the daemon - virtual std::string readBaseConfig() = 0; - - // Return a configuration string from the daemon, if one has been posted. - virtual std::string readOnDemandConfig(bool events, bool activities) = 0; - - // Returns the number of tracked contexts for this device. The daemon has a - // global view. If an unexpedted error occurs, return -1. - virtual int gpuContextCount(uint32_t device) = 0; - - virtual void setCommunicationFabric(bool enabled) = 0; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/Demangle.cpp b/plugins/tensorboard-plugins/libkineto/src/Demangle.cpp deleted file mode 100644 index f84f0b8ec36f621061cb1e8bb8dd948cb8aed7b3..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/Demangle.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "Demangle.h" - -#ifndef _MSC_VER -#include -#endif -#include -#include - -namespace KINETO_NAMESPACE { - -static constexpr int kMaxSymbolSize = 1024; - -std::string demangle(const char* name) { -#ifndef _MSC_VER - if (!name) { - return ""; - } - - if (strlen(name) > kMaxSymbolSize) { - return name; - } - - int status; - size_t len = 0; - char* demangled = abi::__cxa_demangle(name, nullptr, &len, &status); - if (status != 0) { - return name; - } - std::string res(demangled); - // The returned buffer must be freed! - free(demangled); - return res; -#else - // TODO: demangling on Windows - if (!name) { - return ""; - } else { - return name; - } -#endif -} - -std::string demangle(const std::string& name) { - return demangle(name.c_str()); -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/Demangle.h b/plugins/tensorboard-plugins/libkineto/src/Demangle.h deleted file mode 100644 index 6dcf0776f1abf30e7e3614272fa02f6bae1bdf35..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/Demangle.h +++ /dev/null @@ -1,12 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include - -namespace KINETO_NAMESPACE { - -std::string demangle(const char* name); -std::string demangle(const std::string& name); - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/EventProfiler.cpp b/plugins/tensorboard-plugins/libkineto/src/EventProfiler.cpp deleted file mode 100644 index dbf2755238974392ff6205f05a5c80a1733bf2ee..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/EventProfiler.cpp +++ /dev/null @@ -1,635 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "EventProfiler.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "CuptiEventApi.h" -#include "Logger.h" - -using namespace std::chrono; -using std::accumulate; -using std::endl; -using std::map; -using std::ostream; -using std::string; -using std::unique_ptr; -using std::vector; - -namespace KINETO_NAMESPACE { - -static std::mutex& logMutex() { - static std::mutex instance; - return instance; -} - -// --------------------------------------------------------------------- -// class Event -// --------------------------------------------------------------------- - -// Compute domain instance percentiles -PercentileList& Event::percentiles( - PercentileList& pcs, - const SampleSlice& slice) const { - vector instance_values; - instance_values.reserve(instanceCount); - for (int i = 0; i < instanceCount; i++) { - instance_values.push_back(sumInstance(i, slice)); - } - return KINETO_NAMESPACE::percentiles(instance_values, pcs); -} - -// Add up all samples for a given domain instance -int64_t Event::sumInstance(int i, const SampleSlice& slice) const { - auto r = toIdxRange(slice); - auto start = samples_.cbegin(); - std::advance(start, r.first); - auto end = start; - std::advance(end, r.second); - return accumulate(start, end, 0ul, [i](int64_t a, const Sample& b) { - return a + b.second[i]; - }); -} - -// Add up all samples across all domain instances -int64_t Event::sumAll(const SampleSlice& slice) const { - int64_t res = 0; - for (int i = 0; i < instanceCount; i++) { - res += sumInstance(i, slice); - } - return res; -} - -// Print raw sample values for all domains -void Event::printSamples(ostream& s, CUdevice device) const { - // Don't mess up output with interleaved lines - // Probably OK to reuse logMutex() here since this is - // used for debugging, but need to keep an eye on it. - std::lock_guard lock(logMutex()); - s << "Device " << device << " " << name << ":" << endl; - for (const auto& sample : samples_) { - const auto& vals = sample.second; - for (int64_t val : vals) { - s << val << " "; - } - s << endl; - } -} - -// --------------------------------------------------------------------- -// class Metric -// --------------------------------------------------------------------- -Metric::Metric( - string name, - CUpti_MetricID id, - vector events, - CUpti_MetricEvaluationMode eval_mode, - CuptiMetricApi& cupti_metrics) - : name(std::move(name)), - id_(id), - events_(std::move(events)), - evalMode_(eval_mode), - cuptiMetrics_(cupti_metrics), - valueKind_(cuptiMetrics_.valueKind(id)) {} - -// Return per-SM vector as well as total -struct Metric::CalculatedValues Metric::calculate( - map& event_map, - nanoseconds sample_duration, - const SampleSlice& slice) { - vector metric_values; - vector ev_values; - ev_values.reserve(events_.size()); - if (evalMode_ & CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE) { - int instance_count = instanceCount(event_map); - metric_values.reserve(instance_count); - for (int i = 0; i < instance_count; i++) { - ev_values.clear(); - for (CUpti_EventID event_id : events_) { - ev_values.push_back(event_map[event_id].sumInstance(i, slice)); - } - metric_values.push_back(cuptiMetrics_.calculate( - id_, valueKind_, events_, ev_values, sample_duration.count())); - } - } - - // FIXME: Check assumption that all instances are profiled - ev_values.clear(); - for (CUpti_EventID event_id : events_) { - ev_values.push_back(event_map[event_id].sumAll(slice)); - } - SampleValue total = cuptiMetrics_.calculate( - id_, valueKind_, events_, ev_values, sample_duration.count()); - if (evalMode_ & CUPTI_METRIC_EVALUATION_MODE_AGGREGATE) { - metric_values.push_back(total); - } - return {metric_values, std::move(total)}; -} - -void Metric::printDescription(ostream& s) const { - s << fmt::format("{} ({})", name, fmt::join(events_, ",")) << endl; -} - -// --------------------------------------------------------------------- -// class EventGroupSet -// --------------------------------------------------------------------- - -// Each domain has a set of counters. -// Some counters in a domain can be collected simultaneously in a "group" -// Counters from different domains can also be collected at the same time -// Therefore we have a "set of groups", or group set, with counters that -// can all be collected at once. -EventGroupSet::EventGroupSet( - CUpti_EventGroupSet& set, - map& events, - CuptiEventApi& cupti) - : set_(set), events_(events), cuptiEvents_(cupti), enabled_(false) { - for (int g = 0; g < set.numEventGroups; g++) { - CUpti_EventGroup grp = set.eventGroups[g]; - // Profile all domain instances - cuptiEvents_.enablePerInstance(grp); - uint32_t instance_count = cuptiEvents_.instanceCount(grp); - for (const auto& id : cuptiEvents_.eventsInGroup(grp)) { - VLOG(0) << "Instance count for " << id << ":" << instance_count; - events_[id].instanceCount = instance_count; - } - } -} - -EventGroupSet::~EventGroupSet() { - // Disable EventGroupSet in Cupti. - if (enabled_) { - setEnabled(false); - } -} - -// Enable or disable this group set -void EventGroupSet::setEnabled(bool enabled) { - if (enabled && !enabled_) { - cuptiEvents_.enableGroupSet(set_); - } else if (!enabled && enabled_) { - cuptiEvents_.disableGroupSet(set_); - } - enabled_ = enabled; -} - -// Collect counter values for each counter in group set -void EventGroupSet::collectSample() { - auto timestamp = system_clock::now(); - for (int g = 0; g < set_.numEventGroups; g++) { - CUpti_EventGroup grp = set_.eventGroups[g]; - for (const auto& id : cuptiEvents_.eventsInGroup(grp)) { - Event& ev = events_[id]; - vector vals(ev.instanceCount); - // FIXME: Use cuptiEventGroupReadAllEvents - cuptiEvents_.readEvent(grp, id, vals); - - if (VLOG_IS_ON(0)) { - for (int64_t v : vals) { - if (v == CUPTI_EVENT_OVERFLOW) { - LOG(WARNING) << "Counter overflow detected " - << "- decrease sample period!" << endl; - } - } - } - - ev.addSample(timestamp, vals); - } - } - - if (VLOG_IS_ON(1)) { - auto t2 = system_clock::now(); - VLOG(1) << "Device " << cuptiEvents_.device() << " Sample (us): " - << duration_cast(t2 - timestamp).count(); - } -} - -// Print names of events in this group set, ordered by group -void EventGroupSet::printDescription(ostream& s) const { - for (int g = 0; g < set_.numEventGroups; g++) { - s << " Events in group " << g << ": "; - for (const auto& id : cuptiEvents_.eventsInGroup(set_.eventGroups[g])) { - s << id << " (" << events_[id].name << ") "; - } - s << endl; - } -} - -// --------------------------------------------------------------------- -// class EventProfiler -// --------------------------------------------------------------------- - -// Find nearest factor of a number by linear search, -// starting at hi and lo - hi searches up and lo searches down -static int nearestFactor(int hi, int lo, int number) { - return number % hi == 0 - ? hi - : number % lo == 0 ? lo : nearestFactor(hi + 1, lo - 1, number); -} - -static int nearestFactor(int count, int max) { - return nearestFactor(count, count, max); -} - -void EventProfiler::initEvents(const std::set& eventNames) { - events_.clear(); - // Build event map - for (const auto& name : eventNames) { - events_.emplace(cuptiEvents_->eventId(name), name); - } -} - -void EventProfiler::initMetrics(const std::set& metricNames) { - metrics_.clear(); - // Add events from metrics - metrics_.reserve(metricNames.size()); - for (const auto& metric_name : metricNames) { - CUpti_MetricID metric_id = cuptiMetrics_->idFromName(metric_name); - if (metric_id == ~0) { - continue; - } - - const auto& events = cuptiMetrics_->events(metric_id); - vector event_ids; - event_ids.reserve(events.size()); - for (const auto& pair : events) { - CUpti_EventID id = pair.first; - const string& event_name = pair.second; - if (event_name.empty()) { - // For unnamed events, use metric name and event id - // FIXME: For subsequent metrics using the same event, - // this will be confusing - events_.emplace(id, metric_name + "_" + event_name); - } else { - events_.emplace(id, event_name); - } - event_ids.push_back(id); - } - metrics_.emplace_back( - metric_name, - metric_id, - event_ids, - cuptiMetrics_->evaluationMode(metric_id), - *cuptiMetrics_); - } -} - -bool EventProfiler::initEventGroups() { - sets_.clear(); - if (eventGroupSets_) { - cuptiEvents_->destroyGroupSets(eventGroupSets_); - eventGroupSets_ = nullptr; - } - if (events_.empty()) { - return true; - } - - // Determine sets of groups to be collected - vector ids; - ids.reserve(events_.size()); - for (const auto& ev : events_) { - ids.push_back(ev.first); - } - eventGroupSets_ = cuptiEvents_->createGroupSets(ids); - VLOG(0) << "Number of group sets: " << eventGroupSets_->numSets; - for (int i = 0; i < eventGroupSets_->numSets; i++) { - sets_.push_back( - EventGroupSet(eventGroupSets_->sets[i], events_, *cuptiEvents_)); - } - return !sets_.empty(); -} - -static unique_ptr alignAndValidateConfigs( - Config& base, - Config* onDemand) { - auto now = system_clock::now(); - if (!onDemand || - now > - (onDemand->eventProfilerOnDemandStartTime() + - onDemand->eventProfilerOnDemandDuration())) { - base.validate(now); - return base.clone(); - } - - auto res = base.clone(); - res->addEvents(onDemand->eventNames()); - res->addMetrics(onDemand->metricNames()); - - int sample_period = - std::min(base.samplePeriod().count(), onDemand->samplePeriod().count()); - if (sample_period < base.samplePeriod().count() && - (base.samplePeriod().count() % sample_period) != 0) { - sample_period = nearestFactor(sample_period, base.samplePeriod().count()); - LOG(WARNING) - << "On-demand sample period must be a factor of base sample period. " - << "Adjusting from " << onDemand->samplePeriod().count() << "ms to " - << sample_period << "ms."; - } - base.setSamplePeriod(milliseconds(sample_period)); - base.validate(now); - res->setSamplePeriod(base.samplePeriod()); - res->setMultiplexPeriod(base.multiplexPeriod()); - res->validate(now); - onDemand->setSamplePeriod(base.samplePeriod()); - onDemand->setMultiplexPeriod(base.multiplexPeriod()); - onDemand->validate(now); - - return res; -} - -static milliseconds minReportPeriod(const Config& config, int num_sets) { - return config.multiplexPeriod() * num_sets; -} - -static bool canSupportReportPeriod(const Config& config, int num_sets) { - // Can we get through the groups an even number per report period? - milliseconds min_report_period = minReportPeriod(config, num_sets); - return (config.reportPeriod().count() % min_report_period.count()) == 0; -} - -static int completeSamplesPerReport(const Config& config, int num_sets) { - if (num_sets <= 1) { - return config.reportPeriod() / config.samplePeriod(); - } - // Numnber of complete sample collections in the report period - // E.g. if report period is 10000ms, sample period 500ms, - // multiplex period 2000ms and num_sets is 5 then # of complete samples is - // (2000ms / 500ms) * (10000ms / 2000ms / 5) = 4 * 1 = 4 - int samples_per_multiplex_period = - config.multiplexPeriod() / config.samplePeriod(); - int multiplex_periods_per_report = - config.reportPeriod() / config.multiplexPeriod(); - return (multiplex_periods_per_report / num_sets) * - samples_per_multiplex_period; -} - -static bool canSupportSamplesPerReport(const Config& config, int num_sets) { - // Can samples per report can be honored with an exact *full* set of samples? - // We don't support partial samples at this point. - int full_samples_per_report = completeSamplesPerReport(config, num_sets); - return (full_samples_per_report % config.samplesPerReport()) == 0; -} - -static void adjustConfig(Config& config, int num_sets) { - // Don't change sample period and multiplex period here, since that can - // cause overflows and perf degradation. Report period and samples per - // report is OK to change (with warning). - if (!canSupportReportPeriod(config, num_sets)) { - milliseconds min_report_period = minReportPeriod(config, num_sets); - LOG(WARNING) << "Report period must be a multiple of " - << min_report_period.count() << "ms (" << num_sets - << " event sets * " << config.multiplexPeriod().count() - << "ms multiplex period), in order to get complete samples."; - auto new_report_period = - Config::alignUp(config.reportPeriod(), min_report_period); - double sf = - ((double)new_report_period.count()) / config.reportPeriod().count(); - int new_samples_per_report = std::round(config.samplesPerReport() * sf); - LOG(WARNING) << "Adjusting report period from " - << config.reportPeriod().count() << "ms to " - << new_report_period.count() << "ms"; - if (new_samples_per_report != config.samplesPerReport()) { - LOG(WARNING) << "Adjusting samples per report from " - << config.samplesPerReport() << " to " - << new_samples_per_report; - } - config.setReportPeriod(new_report_period); - config.setSamplesPerReport(new_samples_per_report); - } - // Ensure that samples per report can be honored with - // an exact *full* set of samples. Don't support partial - // samples at this point. - if (!canSupportSamplesPerReport(config, num_sets)) { - int full_samples_per_report = completeSamplesPerReport(config, num_sets); - int adjusted_count = - nearestFactor(config.samplesPerReport(), full_samples_per_report); - LOG(WARNING) - << "Samples per report must be such that an even number of " - << "complete samples can be aggregated in each report period. Adjusting" - << " from " << config.samplesPerReport() << " to " << adjusted_count - << " (complete sample count is " << full_samples_per_report << ")"; - config.setSamplesPerReport(adjusted_count); - } -} - -// Prepare profiler -EventProfiler::EventProfiler( - std::unique_ptr cupti_events, - std::unique_ptr cupti_metrics, - vector>& loggers, - vector>& onDemandLoggers) - : cuptiEvents_(std::move(cupti_events)), - cuptiMetrics_(std::move(cupti_metrics)), - loggers_(loggers), - onDemandLoggers_(onDemandLoggers) {} - -void EventProfiler::reportSamples() { - dispatchSamples(*config_, loggers_, baseSamples_); - baseSamples_ += completeSamplesPerReport(*config_, sets_.size()); -} - -void EventProfiler::reportOnDemandSamples() { - dispatchSamples(*onDemandConfig_, onDemandLoggers_, onDemandSamples_); - onDemandSamples_ += completeSamplesPerReport(*onDemandConfig_, sets_.size()); -} - -EventProfiler::~EventProfiler() { - if (eventGroupSets_) { - for (auto& set : sets_) { - set.setEnabled(false); - } - cuptiEvents_->destroyGroupSets(eventGroupSets_); - } - VLOG(0) << "Stopped event profiler for device " << device(); -} - -void EventProfiler::updateLoggers(Config& config, Config* on_demand_config) { - // Update loggers. - for (auto& logger : loggers_) { - std::lock_guard lock(logMutex()); - logger->update(config); - } - - if (on_demand_config) { - // Update onDemand loggers. - for (auto& logger : onDemandLoggers_) { - std::lock_guard lock(logMutex()); - logger->update(*on_demand_config); - } - } -} - -bool EventProfiler::applyConfig(const Config& config) { - // Initialize events, metrics, and event group sets. - // TODO: Send warnings / errors back to dyno for onDemand config - try { - if (!initEventsAndMetrics(config)) { - return false; - } - } catch (const std::exception& ex) { - LOG(WARNING) << "Failed to apply config (" << ex.what() << ")"; - return false; - } - - return true; -} - -bool EventProfiler::initEventsAndMetrics(const Config& config) { - initEvents(config.eventNames()); - initMetrics(config.metricNames()); - // We now have the total list of events to collect - // They need to be organized into groups for multiplexing - if (!initEventGroups()) { - LOG(WARNING) << "No events/metrics initialized successfully"; - return false; - } - - if (VLOG_IS_ON(1)) { - printMetrics(LIBKINETO_DBG_STREAM); - printSets(LIBKINETO_DBG_STREAM); - } - return true; -} - -void EventProfiler::printSets(ostream& s) const { - for (int i = 0; i < sets_.size(); i++) { - s << "Set " << i << endl; - sets_[i].printDescription(s); - } -} - -void EventProfiler::printMetrics(ostream& s) const { - s << "Metrics:" << endl; - for (const Metric& m : metrics_) { - m.printDescription(s); - } -} - -void EventProfiler::printAllSamples(ostream& s, CUdevice device) const { - for (const auto& pair : events_) { - const Event& ev = pair.second; - ev.printSamples(s, device); - } -} - -void EventProfiler::enableNextCounterSet() { - if (sets_.size() > 1) { - auto t1 = system_clock::now(); - - VLOG(1) << "Disabling set " << curEnabledSet_; - sets_[curEnabledSet_].setEnabled(false); - curEnabledSet_ = (curEnabledSet_ + 1) % sets_.size(); - VLOG(1) << "Enabling set " << curEnabledSet_; - sets_[curEnabledSet_].setEnabled(true); - - if (VLOG_IS_ON(1)) { - auto t2 = system_clock::now(); - VLOG(1) << "Switch (us): " - << duration_cast(t2 - t1).count(); - } - } -} - -// Notify listeners of collected samples -void EventProfiler::dispatchSamples( - const Config& config, - const vector>& loggers, - int sample_offset) { - Sample sample(events_.size() + metrics_.size()); - // Normalize values to per second - auto delta = config.reportPeriod() / config.samplesPerReport(); - double sf = 1000.0 * sets_.size() / delta.count(); - for (int i = 0; i < config.samplesPerReport(); i++) { - sample.stats.clear(); - sample.deltaMsec = (delta * i).count(); - SampleSlice slice = {sample_offset, i, config.samplesPerReport()}; - VLOG(1) << "Slice: " << sample_offset << ", " << i << ", " - << config.samplesPerReport(); - for (const auto& pair : events_) { - const Event& ev = pair.second; - int64_t total = std::round(sf * ev.sumAll(slice)); - PercentileList pcs = initPercentiles(config.percentiles()); - normalize(ev.percentiles(pcs, slice), sf); - sample.stats.push_back({ev.name, std::move(pcs), SampleValue(total)}); - } - - for (auto& m : metrics_) { - // calculate returns a pair of per-SM vector and a total - auto vals = m.calculate(events_, delta, slice); - PercentileList pcs = initPercentiles(config.percentiles()); - sample.stats.push_back( - {m.name, std::move(percentiles(vals.perInstance, pcs)), vals.total}); - } - - for (auto& logger : loggers) { - std::lock_guard lock(logMutex()); - logger->handleSample(device(), sample, config.ipcFabricEnabled()); - } - } - - if (VLOG_IS_ON(2)) { - printAllSamples(LIBKINETO_DBG_STREAM, device()); - } -} - -void EventProfiler::configure(Config& config, Config* onDemandConfig) { - if (!sets_.empty()) { - sets_[curEnabledSet_].setEnabled(false); - clearSamples(); - } - - config_ = config.clone(); - onDemandConfig_ = onDemandConfig ? onDemandConfig->clone() : nullptr; - mergedConfig_ = alignAndValidateConfigs(*config_, onDemandConfig_.get()); - if (!applyConfig(*mergedConfig_)) { - LOG(WARNING) << "Failed to apply config!"; - mergedConfig_ = config_->clone(); - applyConfig(*config_); - } - if (!sets_.empty()) { - // Make timing adjustments based on multiplexing requirements. - adjustConfig(*config_, sets_.size()); - if (onDemandConfig_) { - int duration = onDemandConfig_->eventProfilerOnDemandDuration().count(); - LOG(INFO) << "On demand profiler activated for " << duration << " secs"; - adjustConfig(*onDemandConfig_, sets_.size()); - } - // If events or metrics were added or removed, need to tell loggers - updateLoggers(*config_, onDemandConfig_.get()); - } - - curEnabledSet_ = 0; - if (!sets_.empty()) { - sets_[0].setEnabled(true); - } else { - VLOG(0) << "No counters profiled!"; - } - - baseSamples_ = 0; - onDemandSamples_ = 0; -} - -void EventProfiler::collectSample() { - if (sets_.empty()) { - return; - } - sets_[curEnabledSet_].collectSample(); - if (VLOG_IS_ON(1)) { - printAllSamples(LIBKINETO_DBG_STREAM, device()); - } -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/EventProfiler.h b/plugins/tensorboard-plugins/libkineto/src/EventProfiler.h deleted file mode 100644 index fafd5b9bb8336b28b210ba58d588d3a798a73969..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/EventProfiler.h +++ /dev/null @@ -1,341 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "Config.h" -#include "CuptiEventApi.h" -#include "CuptiMetricApi.h" -#include "SampleListener.h" - -namespace KINETO_NAMESPACE { - -// Helper function for computing percentiles (nearest-rank). -// Modifies the input. -template -inline PercentileList& percentiles(std::vector values, PercentileList& pcs) { - auto size = values.size(); - for (auto& x : pcs) { - int idx = std::min(size - 1, (x.first * size) / 100); - std::nth_element(values.begin(), values.begin() + idx, values.end()); - x.second = SampleValue(values[idx]); - } - return pcs; -} - -// Helper function for normalizing a percentile list -// Modifies the input -inline PercentileList& normalize(PercentileList& pcs, double sf) { - for (auto& pc : pcs) { - pc.second *= sf; - } - return pcs; -} - -// A slice of the sample buffer -struct SampleSlice { - // Start offset (samples) - int offset; - // Slice number - int index; - // Out of this many - int count; -}; - -// A sampled event -class Event { - public: - /* implicit */ Event(std::string name) : name(std::move(name)) {} - /* implicit */ Event(const char* name) : name(name) {} - Event() : name("INVALID") {} - - Event(const Event&) = delete; - Event& operator=(const Event&) = delete; - Event(Event&&) = default; - Event& operator=(Event&&) = default; - - void addSample( - std::chrono::time_point timestamp, - const std::vector& values) { - assert(values.size() == instanceCount); - samples_.emplace_back(timestamp, values); - } - - // Sum samples for a single domain instance - int64_t sumInstance(int i, const SampleSlice& slice) const; - - // Sum all samples across all domain instances - int64_t sumAll(const SampleSlice& slice) const; - - // Create list of percentiles - PercentileList& percentiles(PercentileList& pcs, const SampleSlice& slice) - const; - - void eraseSamples(int count) { - auto end = samples_.begin(); - std::advance(end, count); - samples_.erase(samples_.begin(), end); - } - - void clearSamples() { - samples_.clear(); - } - - int sampleCount() { - return samples_.size(); - } - - void printSamples(std::ostream& s, CUdevice device) const; - - // Event name (see nvprof --query-events) - std::string name; - - // Number of domain instances for this event, e.g. number of SMs - int instanceCount = 0; - - private: - std::pair toIdxRange(const SampleSlice& slice) const { - int size = (samples_.size() - slice.offset) / slice.count; - return std::make_pair(slice.offset + (slice.index * size), size); - } - - // List of collected samples, where each sample has values for - // one or more domain instances - using Sample = std::pair< - std::chrono::time_point, - std::vector>; - std::list samples_; -}; - -class Metric { - public: - Metric( - std::string name, - CUpti_MetricID id, - std::vector events, - CUpti_MetricEvaluationMode eval_mode, - CuptiMetricApi& cupti_metrics); - - struct CalculatedValues { - std::vector perInstance; - SampleValue total; - }; - - struct CalculatedValues calculate( - std::map& events, - std::chrono::nanoseconds sample_duration, - const SampleSlice& slice); - - int instanceCount(std::map& events) { - return events[events_[0]].instanceCount; - } - - void printDescription(std::ostream& s) const; - - std::string name; - - private: - CUpti_MetricID id_; - std::vector events_; - CUpti_MetricEvaluationMode evalMode_; - // Calls to CUPTI is encapsulated behind this interface - CuptiMetricApi& cuptiMetrics_; - CUpti_MetricValueKind valueKind_; -}; - -/** - * A set of event groups. - * Holds all the events that may be collected in a single pass. - * A group contains one or more counters for a single domain. - * A group set contains zero or one groups per domain. - */ -class EventGroupSet { - public: - EventGroupSet( - CUpti_EventGroupSet& set, - std::map& events, - CuptiEventApi& cupti); - ~EventGroupSet(); - - EventGroupSet(const EventGroupSet&) = delete; - EventGroupSet& operator=(const EventGroupSet&) = delete; - EventGroupSet(EventGroupSet&&) = default; - EventGroupSet& operator=(EventGroupSet&&) = delete; - - // Number of groups = number of domains profiled - int groupCount() const { - return set_.numEventGroups; - } - - void setEnabled(bool enabled); - // Take a sample of counters in this group set - void collectSample(); - void printDescription(std::ostream& s) const; - - private: - CUpti_EventGroupSet& set_; - std::map& events_; - // Calls to CUPTI is encapsulated behind this interface - CuptiEventApi& cuptiEvents_; - bool enabled_; -}; - -// The sampler -class EventProfiler { - public: - explicit EventProfiler( - std::unique_ptr cupti_events, - std::unique_ptr cupti_metrics, - std::vector>& loggers, - std::vector>& onDemandLoggers); - EventProfiler(const EventProfiler&) = delete; - EventProfiler& operator=(const EventProfiler&) = delete; - ~EventProfiler(); - - void configure(Config& config, Config* onDemandConfig); - - bool isOnDemandActive() { - return !!onDemandConfig_; - } - - // Print the counter sets. Multiple sets will be multiplexed. - void printSets(std::ostream& s) const; - - // Print metrics descriptions - void printMetrics(std::ostream& s) const; - - bool enableForDevice(Config& cfg); - - CUdevice device() { - return cuptiEvents_->device(); - } - - bool setContinuousMode() { - return cuptiEvents_->setContinuousMode(); - } - - std::chrono::milliseconds samplePeriod() { - return mergedConfig_->samplePeriod(); - } - - std::chrono::milliseconds multiplexPeriod() { - return mergedConfig_->multiplexPeriod(); - } - - std::chrono::milliseconds reportPeriod() { - return config_->reportPeriod(); - } - - std::chrono::milliseconds onDemandReportPeriod() { - return onDemandConfig_->reportPeriod(); - } - - // Read values of currently running counters. - void collectSample(); - - void reportSamples(); - void reportOnDemandSamples(); - - bool enabled() { - return sets_.size() > 0; - } - - bool multiplexEnabled() { - return sets_.size() > 1; - } - - // Multiplex counters. - void enableNextCounterSet(); - - void eraseReportedSamples() { - int erase_count = baseSamples_; - if (onDemandConfig_ && - onDemandConfig_->eventProfilerOnDemandDuration().count() > 0) { - erase_count = std::min(baseSamples_, onDemandSamples_); - } - eraseSamples(erase_count); - baseSamples_ -= erase_count; - onDemandSamples_ -= erase_count; - } - - void clearSamples() { - for (auto& pair : events_) { - pair.second.clearSamples(); - } - baseSamples_ = 0; - onDemandSamples_ = 0; - } - - private: - // Functions to initialize profiler based on Config settings. - bool applyConfig(const Config& config); - bool initEventsAndMetrics(const Config& config); - void initEvents(const std::set& eventNames); - void initMetrics(const std::set& metricNames); - bool initEventGroups(); - - PercentileList initPercentiles(const std::vector& percentiles) { - PercentileList res; - res.reserve(percentiles.size()); - for (int p : percentiles) { - res.emplace_back(p, SampleValue(0)); - } - return res; - } - - // Notify listeners of collected samples - void dispatchSamples( - const Config& config, - const std::vector>& loggers, - int report_nr); - - void eraseSamples(int count) { - for (auto& pair : events_) { - pair.second.eraseSamples(count); - } - } - - void updateLoggers(Config& config, Config* on_demand_config); - - // Print all collected samples since last clear. - void printAllSamples(std::ostream& s, CUdevice device) const; - - // Calls to CUPTI is encapsulated behind these interfaces - std::unique_ptr cuptiEvents_; - std::unique_ptr cuptiMetrics_; - // The CUpti API reports event IDs, we must map them to our event objects - std::map events_; - // List of metrics - std::vector metrics_; - // The countert sets needed to collect all counters - std::vector sets_; - // The event group set object returned by Cupti. - // Saved s.t. we can call cuptiEventGroupSetsDestroy to free memory when - // the object is no longer needed. - CUpti_EventGroupSets* eventGroupSets_ = nullptr; - // Current multiplexed counter set - int curEnabledSet_{0}; - - std::unique_ptr config_; - std::unique_ptr onDemandConfig_; - std::unique_ptr mergedConfig_; - int baseSamples_{0}; - int onDemandSamples_{0}; - - // Shared between profiler threads - // Vectors are read-only but calling loggers require lock - const std::vector>& loggers_; - const std::vector>& onDemandLoggers_; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.cpp b/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.cpp deleted file mode 100644 index 0427cc7a90cbc49d31262bcce63f1f81c5b6293f..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.cpp +++ /dev/null @@ -1,423 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "EventProfilerController.h" - -#include -#include -#include - -#include "ConfigLoader.h" -#include "CuptiEventApi.h" -#include "CuptiMetricApi.h" -#include "EventProfiler.h" -#include "output_csv.h" - -#include "Logger.h" -#include "ThreadUtil.h" - -using namespace std::chrono; -using std::unique_ptr; -using std::vector; - -namespace KINETO_NAMESPACE { - -namespace { - -vector(const Config&)>>& -loggerFactories() { - static vector(const Config&)>> - factories; - return factories; -} - -vector(const Config&)>>& -onDemandLoggerFactories() { - static vector(const Config&)>> - factories; - return factories; -} - -vector> makeLoggers(const Config& config) { - vector> loggers; - for (const auto& factory : loggerFactories()) { - loggers.push_back(factory(config)); - } - loggers.push_back(std::make_unique()); - loggers.push_back(std::make_unique()); - return loggers; -} - -vector> makeOnDemandLoggers( - const Config& config) { - vector> loggers; - for (const auto& factory : onDemandLoggerFactories()) { - loggers.push_back(factory(config)); - } - loggers.push_back(std::make_unique()); - return loggers; -} - -vector>& loggers(const Config& config) { - static auto res = makeLoggers(config); - return res; -} - -vector>& onDemandLoggers( - const Config& config) { - static auto res = makeOnDemandLoggers(config); - return res; -} - -} // anon namespace - -// Keep an eye on profiling threads. -// We've observed deadlocks in Cuda11 in libcuda / libcupti.. -namespace detail { - -class HeartbeatMonitor { - - public: - ~HeartbeatMonitor() { - stopMonitoring(); - } - - static HeartbeatMonitor& instance() { - static HeartbeatMonitor monitor; - return monitor; - } - - void profilerHeartbeat() { - int32_t tid = systemThreadId(); - std::lock_guard lock(mutex_); - profilerAliveMap_[tid]++; - } - - void setPeriod(seconds period) { - { - std::lock_guard lock(mutex_); - if (period_ == period) { - return; - } - period_ = period; - } - if (period == seconds(0)) { - stopMonitoring(); - } else { - startMonitoring(); - } - } - - private: - HeartbeatMonitor() = default; - - void monitorLoop() { - std::unique_lock lock(mutex_); - while(!stopMonitor_) { - auto cv_status = condVar_.wait_for(lock, seconds(period_)); - // Don't perform check on spurious wakeup or on notify - if (cv_status == std::cv_status::timeout) { - for (auto& pair : profilerAliveMap_) { - int32_t tid = pair.first; - int& i = pair.second; - if (i == 0) { - LOG(ERROR) << "Thread " << tid << " appears stuck!"; - } - i = 0; - } - } - } - } - - void startMonitoring() { - if (!monitorThread_) { - VLOG(0) << "Starting monitoring thread"; - stopMonitor_ = false; - monitorThread_ = std::make_unique( - &HeartbeatMonitor::monitorLoop, this); - } - } - - void stopMonitoring() { - if (monitorThread_) { - VLOG(0) << "Stopping monitoring thread"; - stopMonitor_ = true; - condVar_.notify_one(); - monitorThread_->join(); - monitorThread_ = nullptr; - VLOG(0) << "Monitoring thread terminated"; - } - } - - std::map profilerAliveMap_; - std::unique_ptr monitorThread_; - std::mutex mutex_; - std::condition_variable condVar_; - std::atomic_bool stopMonitor_{false}; - seconds period_{0}; -}; - -} // namespace detail - -namespace { -// Profiler map singleton -std::map>& profilerMap() { - static std::map> instance; - return instance; -} - -void reportLateSample( - int sleepMs, - int sampleMs, - int reportMs, - int reprogramMs) { - LOG_EVERY_N(WARNING, 10) << "Lost sample due to delays (ms): " << sleepMs - << ", " << sampleMs << ", " << reportMs << ", " - << reprogramMs; -} - -void configureHeartbeatMonitor( - detail::HeartbeatMonitor& monitor, const Config& base, const Config* onDemand) { - seconds base_period = - base.eventProfilerHeartbeatMonitorPeriod(); - seconds on_demand_period = !onDemand ? seconds(0) : - onDemand->eventProfilerHeartbeatMonitorPeriod(); - monitor.setPeriod( - on_demand_period > seconds(0) ? on_demand_period : base_period); -} - -} // anon namespace - -void EventProfilerController::addLoggerFactory( - std::function(const Config&)> factory) { - loggerFactories().push_back(factory); -} - -void EventProfilerController::addOnDemandLoggerFactory( - std::function(const Config&)> factory) { - onDemandLoggerFactories().push_back(factory); -} - -EventProfilerController::EventProfilerController( - CUcontext context, - ConfigLoader& configLoader, - detail::HeartbeatMonitor& heartbeatMonitor) - : configLoader_(configLoader), heartbeatMonitor_(heartbeatMonitor) { - auto cupti_events = std::make_unique(context); - auto cupti_metrics = - std::make_unique(cupti_events->device()); - configLoader_.addHandler( - ConfigLoader::ConfigKind::EventProfiler, this); - auto config = configLoader.getConfigCopy(); - profiler_ = std::make_unique( - std::move(cupti_events), - std::move(cupti_metrics), - loggers(*config), - onDemandLoggers(*config)); - profilerThread_ = std::make_unique( - &EventProfilerController::profilerLoop, this); -} - -EventProfilerController::~EventProfilerController() { - if (profilerThread_) { - // signaling termination of the profiler loop - stopRunloop_ = true; - profilerThread_->join(); - } - configLoader_.removeHandler( - ConfigLoader::ConfigKind::EventProfiler, this); - VLOG(0) << "Stopped event profiler"; -} - -// Must be called under lock -void EventProfilerController::start(CUcontext ctx, ConfigLoader& configLoader) { - profilerMap()[ctx] = unique_ptr( - new EventProfilerController( - ctx, configLoader, detail::HeartbeatMonitor::instance())); -} - -// Must be called under lock -void EventProfilerController::stop(CUcontext ctx) { - profilerMap()[ctx] = nullptr; -} - -bool EventProfilerController::canAcceptConfig() { - std::lock_guard guard(mutex_); - return !newOnDemandConfig_; -} - -void EventProfilerController::acceptConfig(const Config& config) { - if (config.eventProfilerOnDemandDuration().count() == 0) { - // Ignore - not for this profiler - return; - } - std::lock_guard guard(mutex_); - if (newOnDemandConfig_) { - LOG(ERROR) << "On demand request already queued - ignoring new request"; - return; - } - newOnDemandConfig_ = config.clone(); - LOG(INFO) << "Received new on-demand config"; -} - -bool EventProfilerController::enableForDevice(Config& cfg) { - // FIXME: Use device unique id! - if (!cfg.eventProfilerEnabledForDevice(profiler_->device())) { - return false; - } - // context count includes the new context - int instances = configLoader_.contextCountForGpu(profiler_->device()); - VLOG(0) << "Device context count: " << instances; - return instances >= 0 && instances <= cfg.maxEventProfilersPerGpu(); -} - -void EventProfilerController::profilerLoop() { - // We limit the number of profilers that can exist per GPU - auto config = configLoader_.getConfigCopy(); - if (!enableForDevice(*config)) { - VLOG(0) << "Not starting EventProfiler - profilers for GPU " - << profiler_->device() << " exceeds profilers per GPU limit (" - << config->maxEventProfilersPerGpu() << ")"; - return; - } - - if (!profiler_->setContinuousMode()) { - VLOG(0) << "Continuous mode not supported for GPU " - << profiler_->device() << ". Not starting Event Profiler."; - return; - } - - VLOG(0) << "Starting Event Profiler for GPU " << profiler_->device(); - setThreadName("CUPTI Event Profiler"); - - time_point next_sample_time; - time_point next_report_time; - time_point next_on_demand_report_time; - time_point next_multiplex_time; - std::unique_ptr on_demand_config = nullptr; - bool reconfigure = true; - bool restart = true; - int report_count = 0; - int on_demand_report_count = 0; - while (!stopRunloop_) { - heartbeatMonitor_.profilerHeartbeat(); - if (configLoader_.hasNewConfig(*config)) { - config = configLoader_.getConfigCopy(); - VLOG(0) << "Base config changed"; - report_count = 0; - reconfigure = true; - } - - auto now = system_clock::now(); - if (on_demand_config && - now > (on_demand_config->eventProfilerOnDemandStartTime() + - on_demand_config->eventProfilerOnDemandDuration())) { - on_demand_config = nullptr; - LOG(INFO) << "On-demand profiling complete"; - reconfigure = true; - } - - if (!profiler_->isOnDemandActive()) { - std::lock_guard lock(mutex_); - if (newOnDemandConfig_) { - VLOG(0) << "Received on-demand config, reconfiguring"; - on_demand_config = std::move(newOnDemandConfig_); - reconfigure = true; - on_demand_report_count = 0; - } - } - - if (reconfigure) { - try { - profiler_->configure(*config, on_demand_config.get()); - } catch (const std::exception& ex) { - LOG(ERROR) << "Encountered error while configuring event profiler: " - << ex.what(); - // Exit profiling entirely when encountering an error here - // as it indicates a serious problem or bug. - break; - } - configureHeartbeatMonitor( - heartbeatMonitor_, *config, on_demand_config.get()); - reconfigure = false; - restart = true; - } - - if (restart) { - now = system_clock::now(); - next_sample_time = now + profiler_->samplePeriod(); - next_report_time = now + profiler_->reportPeriod(); - if (profiler_->isOnDemandActive()) { - next_on_demand_report_time = now + profiler_->onDemandReportPeriod(); - } - next_multiplex_time = now + profiler_->multiplexPeriod(); - // Collect an initial sample and throw it away - // The next sample is the first valid one - profiler_->collectSample(); - profiler_->clearSamples(); - restart = false; - } - - auto start_sleep = now; - while (now < next_sample_time) { - /* sleep override */ - std::this_thread::sleep_for(next_sample_time - now); - now = system_clock::now(); - } - int sleep_time = duration_cast(now - start_sleep).count(); - - auto start_sample = now; - profiler_->collectSample(); - now = system_clock::now(); - int sample_time = duration_cast(now - start_sample).count(); - - next_sample_time += profiler_->samplePeriod(); - if (now > next_sample_time) { - reportLateSample(sleep_time, sample_time, 0, 0); - restart = true; - continue; - } - - auto start_report = now; - if (now > next_report_time) { - VLOG(1) << "Report #" << report_count++; - profiler_->reportSamples(); - next_report_time += profiler_->reportPeriod(); - } - if (profiler_->isOnDemandActive() && now > next_on_demand_report_time) { - VLOG(1) << "OnDemand Report #" << on_demand_report_count++; - profiler_->reportOnDemandSamples(); - next_on_demand_report_time += profiler_->onDemandReportPeriod(); - } - profiler_->eraseReportedSamples(); - now = system_clock::now(); - int report_time = duration_cast(now - start_report).count(); - - if (now > next_sample_time) { - reportLateSample(sleep_time, sample_time, report_time, 0); - restart = true; - continue; - } - - auto start_multiplex = now; - if (profiler_->multiplexEnabled() && now > next_multiplex_time) { - profiler_->enableNextCounterSet(); - next_multiplex_time += profiler_->multiplexPeriod(); - } - now = system_clock::now(); - int multiplex_time = - duration_cast(now - start_multiplex).count(); - - if (now > next_sample_time) { - reportLateSample(sleep_time, sample_time, report_time, multiplex_time); - restart = true; - } - - VLOG(0) << "Runloop execution time: " - << duration_cast(now - start_sample).count() << "ms"; - } - - VLOG(0) << "Device " << profiler_->device() - << ": Exited event profiling loop"; -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.h b/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.h deleted file mode 100644 index 007a82faa9289ada9256d09907167471eb6520b9..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/EventProfilerController.h +++ /dev/null @@ -1,63 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include - -#include - -#include "ConfigLoader.h" - -namespace KINETO_NAMESPACE { - -class Config; -class ConfigLoader; -class EventProfiler; -class SampleListener; - -namespace detail { -class HeartbeatMonitor; -} - -class EventProfilerController : public ConfigLoader::ConfigHandler { - public: - EventProfilerController(const EventProfilerController&) = delete; - EventProfilerController& operator=(const EventProfilerController&) = delete; - - ~EventProfilerController(); - - static void start(CUcontext ctx, ConfigLoader& configLoader); - static void stop(CUcontext ctx); - - static void addLoggerFactory( - std::function(const Config&)> factory); - - static void addOnDemandLoggerFactory( - std::function(const Config&)> factory); - - bool canAcceptConfig() override; - - void acceptConfig(const Config& config) override; - - private: - explicit EventProfilerController( - CUcontext context, - ConfigLoader& configLoader, - detail::HeartbeatMonitor& heartbeatMonitor); - bool enableForDevice(Config& cfg); - void profilerLoop(); - - ConfigLoader& configLoader_; - std::unique_ptr newOnDemandConfig_; - detail::HeartbeatMonitor& heartbeatMonitor_; - std::unique_ptr profiler_; - std::unique_ptr profilerThread_; - std::atomic_bool stopRunloop_{false}; - std::mutex mutex_; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/GenericTraceActivity.cpp b/plugins/tensorboard-plugins/libkineto/src/GenericTraceActivity.cpp deleted file mode 100644 index 4e00b1256c4fa301e288e619ee9ef8c56c8b8569..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/GenericTraceActivity.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "GenericTraceActivity.h" -#include "output_base.h" - -namespace libkineto { - void GenericTraceActivity::log(ActivityLogger& logger) const { - logger.handleGenericActivity(*this); - } -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/ILoggerObserver.cpp b/plugins/tensorboard-plugins/libkineto/src/ILoggerObserver.cpp deleted file mode 100644 index f0106578811837c9cc677def30d5697d43a94221..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ILoggerObserver.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "ILoggerObserver.h" - -#if !USE_GOOGLE_LOG - -#include -#include - -namespace libkineto { - -struct LoggerTypeName { - constexpr LoggerTypeName(const char* n, LoggerOutputType t) : name(n), type(t) {}; - const char* name; - LoggerOutputType type; -}; - -static constexpr std::array LoggerMap{{ - {"VERBOSE", LoggerOutputType::VERBOSE}, - {"INFO", LoggerOutputType::INFO}, - {"WARNING", LoggerOutputType::WARNING}, - {"ERROR", LoggerOutputType::ERROR}, - {"STAGE", LoggerOutputType::STAGE}, - {"???", LoggerOutputType::ENUM_COUNT} -}}; - -static constexpr bool matchingOrder(int idx = 0) { - return LoggerMap[idx].type == LoggerOutputType::ENUM_COUNT || - ((idx == (int) LoggerMap[idx].type) && matchingOrder(idx + 1)); -} -static_assert(matchingOrder(), "LoggerTypeName map is out of order"); - -const char* toString(LoggerOutputType t) { - if(t < VERBOSE || t >= ENUM_COUNT) { - return LoggerMap[ENUM_COUNT].name; - } - return LoggerMap[(int)t].name; -} - -LoggerOutputType toLoggerOutputType(const std::string& str) { - for (int i = 0; i < LoggerTypeCount; i++) { - if (str == LoggerMap[i].name) { - return LoggerMap[i].type; - } - } - throw std::invalid_argument(fmt::format("Invalid activity type: {}", str)); -} - -} // namespace libkineto - - -#endif // !USE_GOOGLE_LOG diff --git a/plugins/tensorboard-plugins/libkineto/src/Logger.cpp b/plugins/tensorboard-plugins/libkineto/src/Logger.cpp deleted file mode 100644 index dbde765f51f7a5f03c31a9c79e6d00ce9a2070b6..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/Logger.cpp +++ /dev/null @@ -1,136 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "Logger.h" -#include "ILoggerObserver.h" - -#ifndef USE_GOOGLE_LOG - -#include -#include -#include -#include -#include - -#include -#include - -#include "ThreadUtil.h" - -namespace KINETO_NAMESPACE { - -std::atomic_int Logger::severityLevel_{VERBOSE}; -std::atomic_int Logger::verboseLogLevel_{-1}; -std::atomic Logger::verboseLogModules_{~0ull}; - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wglobal-constructors" -std::mutex Logger::loggerObserversMutex_; -#pragma GCC diagnostic pop - - -Logger::Logger(int severity, int line, const char* filePath, int errnum) - : buf_(), out_(LIBKINETO_DBG_STREAM), errnum_(errnum), messageSeverity_(severity) { - buf_ << toString((LoggerOutputType) severity) << ":"; - - const auto tt = - std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); - const char* file = strrchr(filePath, '/'); - buf_ << fmt::format("{:%Y-%m-%d %H:%M:%S}", fmt::localtime(tt)) << " " - << processId() << ":" << systemThreadId() << " " - << (file ? file + 1 : filePath) << ":" << line << "] "; -} - -Logger::~Logger() { -#ifdef __linux__ - if (errnum_ != 0) { - thread_local char buf[1024]; - buf_ << " : " << strerror_r(errnum_, buf, sizeof(buf)); - } -#endif - - { - std::lock_guard guard(loggerObserversMutex_); - for (auto* observer : loggerObservers()) { - // Output to observers. Current Severity helps keep track of which bucket the output goes. - if (observer) { - observer->write(buf_.str(), (LoggerOutputType) messageSeverity_); - } - } - } - - // Finally, print to terminal or console. - out_ << buf_.str() << std::endl; -} - -void Logger::setVerboseLogModules(const std::vector& modules) { - uint64_t mask = 0; - if (modules.empty()) { - mask = ~0ull; - } else { - for (const std::string& name : modules) { - mask |= hash(name.c_str()); - } - } - verboseLogModules_ = mask; -} - -void Logger::addLoggerObserver(ILoggerObserver* observer) { - if (observer == nullptr) { - return; - } - std::lock_guard guard(loggerObserversMutex_); - loggerObservers().insert(observer); -} - -void Logger::removeLoggerObserver(ILoggerObserver* observer) { - std::lock_guard guard(loggerObserversMutex_); - loggerObservers().erase(observer); -} - -void Logger::addLoggerObserverDevice(int64_t device) { - std::lock_guard guard(loggerObserversMutex_); - for (auto observer : loggerObservers()) { - observer->addDevice(device); - } -} - -void Logger::addLoggerObserverEventCount(int64_t count) { - std::lock_guard guard(loggerObserversMutex_); - for (auto observer : loggerObservers()) { - observer->addEventCount(count); - } -} - -void Logger::setLoggerObserverTraceDurationMS(int64_t duration) { - std::lock_guard guard(loggerObserversMutex_); - for (auto observer : loggerObservers()) { - observer->setTraceDurationMS(duration); - } -} - -void Logger::setLoggerObserverTraceID(const std::string& tid) { - std::lock_guard guard(loggerObserversMutex_); - for (auto observer : loggerObservers()) { - observer->setTraceID(tid); - } -} - -void Logger::setLoggerObserverGroupTraceID(const std::string& gtid) { - std::lock_guard guard(loggerObserversMutex_); - for (auto observer : loggerObservers()) { - observer->setGroupTraceID(gtid); - } -} - -void Logger::addLoggerObserverDestination(const std::string& dest) { - std::lock_guard guard(loggerObserversMutex_); - for (auto observer : loggerObservers()) { - observer->addDestination(dest); - } -} - -} // namespace KINETO_NAMESPACE - -#endif // USE_GOOGLE_LOG diff --git a/plugins/tensorboard-plugins/libkineto/src/Logger.h b/plugins/tensorboard-plugins/libkineto/src/Logger.h deleted file mode 100644 index 868fc84b9f4ee86d88805bed81468a5df6988257..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/Logger.h +++ /dev/null @@ -1,244 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include - -#define LIBKINETO_DBG_STREAM std::cerr - -#if USE_GOOGLE_LOG - -#include - -#define SET_LOG_SEVERITY_LEVEL(level) -#define SET_LOG_VERBOSITY_LEVEL(level, modules) -#define LOGGER_OBSERVER_ADD_DEVICE(device) -#define LOGGER_OBSERVER_ADD_EVENT_COUNT(count) -#define LOGGER_OBSERVER_SET_TRACE_DURATION_MS(duration) -#define LOGGER_OBSERVER_SET_TRACE_ID(tid) -#define LOGGER_OBSERVER_SET_GROUP_TRACE_ID(gtid) -#define LOGGER_OBSERVER_ADD_DESTINATION(dest) -#define UST_LOGGER_MARK_COMPLETED(stage) - -#else // !USE_GOOGLE_LOG -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "ILoggerObserver.h" - -#ifdef _MSC_VER -// unset a predefined ERROR (windows) -#undef ERROR -#endif // _MSC_VER - -namespace KINETO_NAMESPACE { - -class Logger { - public: - Logger(int severity, int line, const char* filePath, int errnum = 0); - ~Logger(); - - inline std::ostream& stream() { - return buf_; - } - - static inline void setSeverityLevel(int level) { - severityLevel_ = level; - } - - static inline int severityLevel() { - return severityLevel_; - } - - static inline void setVerboseLogLevel(int level) { - verboseLogLevel_ = level; - } - - static inline int verboseLogLevel() { - return verboseLogLevel_; - } - - // This is constexpr so that the hash for a file name is computed at compile - // time when used in the VLOG macros. - // This way, there is no string comparison for matching VLOG modules, - // only a comparison of pre-computed hashes. - // No fancy hashing needed here. It's pretty inefficient (one character - // at a time) but the strings are not large and it's not in the critical path. - static constexpr uint64_t rol(uint64_t val, int amount) { - return val << amount | val >> (63 - amount); - } - static constexpr uint64_t hash(const char* s) { - uint64_t hash = hash_rec(s, 0); - return hash & rol(0x41a0240682483014ull, hash & 63); - } - static constexpr uint64_t hash_rec(const char* s, int off) { - // Random constants! - return (!s[off] ? 57ull : (hash_rec(s, off + 1) * 293) ^ s[off]); - } - static constexpr const char* basename(const char* s, int off = 0) { - return !s[off] - ? s - : s[off] == '/' ? basename(&s[off + 1]) : basename(s, off + 1); - } - - static void setVerboseLogModules(const std::vector& modules); - - static inline uint64_t verboseLogModules() { - return verboseLogModules_; - } - - static void clearLoggerObservers() { - std::lock_guard g(loggerObserversMutex_); - loggerObservers().clear(); - } - - static void addLoggerObserver(ILoggerObserver* observer); - - static void removeLoggerObserver(ILoggerObserver* observer); - - static void addLoggerObserverDevice(int64_t device); - - static void addLoggerObserverEventCount(int64_t count); - - static void setLoggerObserverTraceDurationMS(int64_t duration); - - static void setLoggerObserverTraceID(const std::string& tid); - - static void setLoggerObserverGroupTraceID(const std::string& gtid); - - static void addLoggerObserverDestination(const std::string& dest); - - private: - std::stringstream buf_; - std::ostream& out_; - int errnum_; - int messageSeverity_; - static std::atomic_int severityLevel_; - static std::atomic_int verboseLogLevel_; - static std::atomic verboseLogModules_; - static std::set& loggerObservers() { - static auto* inst = new std::set(); - return *inst; - } - static std::mutex loggerObserversMutex_; -}; - -class VoidLogger { - public: - VoidLogger() {} - void operator&(std::ostream&) {} -}; - -} // namespace KINETO_NAMESPACE - -#ifdef LOG // Undefine in case these are already defined (quite likely) -#undef LOG -#undef LOG_IS_ON -#undef LOG_IF -#undef LOG_EVERY_N -#undef LOG_IF_EVERY_N -#undef DLOG -#undef DLOG_IF -#undef VLOG -#undef VLOG_IF -#undef VLOG_EVERY_N -#undef VLOG_IS_ON -#undef DVLOG -#undef LOG_FIRST_N -#undef CHECK -#undef DCHECK -#undef DCHECK_EQ -#undef PLOG -#undef PCHECK -#undef LOG_OCCURRENCES -#endif - -#define LOG_IS_ON(severity) \ - (severity >= libkineto::Logger::severityLevel()) - -#define LOG_IF(severity, condition) \ - !(LOG_IS_ON(severity) && (condition)) ? (void)0 : libkineto::VoidLogger() & \ - libkineto::Logger(severity, __LINE__, __FILE__).stream() - -#define LOG(severity) LOG_IF(severity, true) - -#define LOCAL_VARNAME_CONCAT(name, suffix) _##name##suffix##_ - -#define LOCAL_VARNAME(name) LOCAL_VARNAME_CONCAT(name, __LINE__) - -#define LOG_OCCURRENCES LOCAL_VARNAME(log_count) - -#define LOG_EVERY_N(severity, rate) \ - static int LOG_OCCURRENCES = 0; \ - LOG_IF(severity, LOG_OCCURRENCES++ % rate == 0) \ - << "(x" << LOG_OCCURRENCES << ") " - -template -struct __to_constant__ { - static const uint64_t val = n; -}; -#define FILENAME_HASH \ - __to_constant__::val -#define VLOG_IS_ON(verbosity) \ - (libkineto::Logger::verboseLogLevel() >= verbosity && \ - (libkineto::Logger::verboseLogModules() & FILENAME_HASH) == FILENAME_HASH) - -#define VLOG_IF(verbosity, condition) \ - LOG_IF(VERBOSE, VLOG_IS_ON(verbosity) && (condition)) - -#define VLOG(verbosity) VLOG_IF(verbosity, true) - -#define VLOG_EVERY_N(verbosity, rate) \ - static int LOG_OCCURRENCES = 0; \ - VLOG_IF(verbosity, LOG_OCCURRENCES++ % rate == 0) \ - << "(x" << LOG_OCCURRENCES << ") " - -#define PLOG(severity) \ - libkineto::Logger(severity, __LINE__, __FILE__, errno).stream() - -#define SET_LOG_SEVERITY_LEVEL(level) \ - libkineto::Logger::setSeverityLevel(level) - -#define SET_LOG_VERBOSITY_LEVEL(level, modules) \ - libkineto::Logger::setVerboseLogLevel(level); \ - libkineto::Logger::setVerboseLogModules(modules) - -// Logging the set of devices the trace is collect on. -#define LOGGER_OBSERVER_ADD_DEVICE(device_count) \ - libkineto::Logger::addLoggerObserverDevice(device_count) - -// Incrementing the number of events collected by this trace. -#define LOGGER_OBSERVER_ADD_EVENT_COUNT(count) \ - libkineto::Logger::addLoggerObserverEventCount(count) - -// Record duration of trace in milliseconds. -#define LOGGER_OBSERVER_SET_TRACE_DURATION_MS(duration) \ - libkineto::Logger::setLoggerObserverTraceDurationMS(duration) - -// Record the trace id when given. -#define LOGGER_OBSERVER_SET_TRACE_ID(tid) \ - libkineto::Logger::setLoggerObserverTraceID(tid) - -// Record the group trace id when given. -#define LOGGER_OBSERVER_SET_GROUP_TRACE_ID(gtid) \ - libkineto::Logger::setLoggerObserverGroupTraceID(gtid) - -// Log the set of destinations the trace is sent to. -#define LOGGER_OBSERVER_ADD_DESTINATION(dest) \ - libkineto::Logger::addLoggerObserverDestination(dest) - -// UST Logger Semantics to describe when a stage is complete. -#define UST_LOGGER_MARK_COMPLETED(stage) \ - LOG(libkineto::LoggerOutputType::STAGE) << "Completed Stage: " << stage - -#endif // USE_GOOGLE_LOG diff --git a/plugins/tensorboard-plugins/libkineto/src/LoggerCollector.h b/plugins/tensorboard-plugins/libkineto/src/LoggerCollector.h deleted file mode 100644 index bb05aab218dc137cfe2f0107694a049ee2ea6508..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/LoggerCollector.h +++ /dev/null @@ -1,70 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#if !USE_GOOGLE_LOG - -#include -#include -#include -#include - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "ILoggerObserver.h" - -namespace KINETO_NAMESPACE { - -using namespace libkineto; - -class LoggerCollector : public ILoggerObserver { - public: - LoggerCollector() : buckets_() {} - - void write(const std::string& message, LoggerOutputType ot = ERROR) override { - // Skip STAGE output type which is only used by USTLoggerCollector. - if (ot != STAGE) { - buckets_[ot].push_back(message); - } - } - - const std::map> extractCollectorMetadata() override { - return buckets_; - } - - void reset() override { - trace_duration_ms = 0; - event_count = 0; - destinations.clear(); - } - - void addDevice(const int64_t device) override { - devices.insert(device); - } - - void setTraceDurationMS(const int64_t duration) override { - trace_duration_ms = duration; - } - - void addEventCount(const int64_t count) override { - event_count += count; - } - - void addDestination(const std::string& dest) override { - destinations.insert(dest); - } - - protected: - std::map> buckets_; - - // These are useful metadata to collect from CUPTIActivityProfiler for internal tracking. - std::set devices; - int64_t trace_duration_ms{0}; - std::atomic event_count{0}; - std::set destinations; - -}; - -} // namespace KINETO_NAMESPACE - -#endif // !USE_GOOGLE_LOG diff --git a/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.cpp b/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.cpp deleted file mode 100644 index 73eff13e2a08bcfecefb03f5b229bde89b7e96cb..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.cpp +++ /dev/null @@ -1,569 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "RoctracerActivityApi.h" - -#include -#include -#include - -#include "Demangle.h" -#include "output_base.h" -#include "ThreadUtil.h" - -typedef uint64_t timestamp_t; - -static timestamp_t timespec_to_ns(const timespec& time) { - return ((timestamp_t)time.tv_sec * 1000000000) + time.tv_nsec; - } - -using namespace std::chrono; - -namespace KINETO_NAMESPACE { - -constexpr size_t kBufSize(2 * 1024 * 1024); - -RoctracerActivityApi& RoctracerActivityApi::singleton() { - static RoctracerActivityApi instance; - return instance; -} - -RoctracerActivityApi::RoctracerActivityApi() { - gpuTraceBuffers_ = std::make_unique>(); -} - -RoctracerActivityApi::~RoctracerActivityApi() { - disableActivities(std::set()); - endTracing(); -} - -void RoctracerActivityApi::pushCorrelationID(int id, CorrelationFlowType type) { -#ifdef HAS_ROCTRACER - if (!singleton().externalCorrelationEnabled_) { - return; - } - // placeholder -#endif -} - -void RoctracerActivityApi::popCorrelationID(CorrelationFlowType type) { -#ifdef HAS_ROCTRACER - if (!singleton().externalCorrelationEnabled_) { - return; - } - // placeholder -#endif -} - -void RoctracerActivityApi::setMaxBufferSize(int size) { - maxGpuBufferCount_ = 1 + size / kBufSize; -} - -int RoctracerActivityApi::processActivities( - ActivityLogger& logger) { - // Find offset to map from monotonic clock to system clock. - // This will break time-ordering of events but is status quo. - - timespec t0, t1, t00; - clock_gettime(CLOCK_REALTIME, &t0); - clock_gettime(CLOCK_MONOTONIC, &t1); - clock_gettime(CLOCK_REALTIME, &t00); - - const timestamp_t toffset = (timespec_to_ns(t0) >> 1) + (timespec_to_ns(t00) >> 1) - timespec_to_ns(t1); - - int count = 0; - - // Basic Api calls - - for (auto &item : rows_) { - GenericTraceActivity a; - a.startTime = (item.begin + toffset) / 1000; - a.endTime = (item.end + toffset) / 1000; - a.id = item.id; - a.device = item.pid; - a.resource = item.tid; - a.activityType = ActivityType::CUDA_RUNTIME; - a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); - a.flow.id = item.id; - a.flow.type = kLinkAsyncCpuGpu; - a.flow.start = true; - - logger.handleGenericActivity(a); - ++count; - } - - // Malloc/Free calls - for (auto &item : mallocRows_) { - GenericTraceActivity a; - a.startTime = (item.begin + toffset) / 1000; - a.endTime = (item.end + toffset) / 1000; - a.id = item.id; - a.device = item.pid; - a.resource = item.tid; - a.activityType = ActivityType::CUDA_RUNTIME; - a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); - a.flow.id = item.id; - a.flow.type = kLinkAsyncCpuGpu; - a.flow.start = true; - - a.addMetadata("ptr", item.ptr); - if (item.cid == HIP_API_ID_hipMalloc) { - a.addMetadata("size", item.size); - } - - logger.handleGenericActivity(a); - ++count; - } - - // HipMemcpy calls - for (auto &item : copyRows_) { - GenericTraceActivity a; - a.startTime = (item.begin + toffset) / 1000; - a.endTime = (item.end + toffset) / 1000; - a.id = item.id; - a.device = item.pid; - a.resource = item.tid; - a.activityType = ActivityType::CUDA_RUNTIME; - a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); - a.flow.id = item.id; - a.flow.type = kLinkAsyncCpuGpu; - a.flow.start = true; - - a.addMetadata("src", item.src); - a.addMetadata("dst", item.dst); - a.addMetadata("size", item.size); - a.addMetadata("kind", item.kind); - if ((item.cid == HIP_API_ID_hipMemcpyAsync) || (item.cid == HIP_API_ID_hipMemcpyWithStream)) { - a.addMetadata("stream", fmt::format("{}", reinterpret_cast(item.stream))); - } - - logger.handleGenericActivity(a); - ++count; - } - - // Kernel Launch Api calls - - for (auto &item : kernelRows_) { - GenericTraceActivity a; - a.startTime = (item.begin + toffset) / 1000; - a.endTime = (item.end + toffset) / 1000; - a.id = item.id; - a.device = item.pid; - a.resource = item.tid; - a.activityType = ActivityType::CUDA_RUNTIME; - a.activityName = std::string(roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, item.cid, 0)); - a.flow.id = item.id; - a.flow.type = kLinkAsyncCpuGpu; - a.flow.start = true; - - if (item.functionAddr != nullptr) { - a.addMetadataQuoted( - "kernel", demangle(hipKernelNameRefByPtr(item.functionAddr, item.stream))); - } - else if (item.function != nullptr) { - a.addMetadataQuoted( - "kernel", demangle(hipKernelNameRef(item.function))); - } - a.addMetadata("grid dim", fmt::format("[{}, {}, {}]", item.gridX, item.gridY, item.gridZ)); - a.addMetadata("block dim", fmt::format("[{}, {}, {}]", item.workgroupX, item.workgroupY, item.workgroupZ)); - a.addMetadata("shared size", item.groupSegmentSize); - a.addMetadata("stream", fmt::format("{}", reinterpret_cast(item.stream))); - - // Stash launches to tie to the async ops - kernelLaunches_[a.id] = a; - - // Stash kernel names to tie to the async ops - std::string name; - if (item.functionAddr != nullptr) { - name = demangle(hipKernelNameRefByPtr(item.functionAddr, item.stream)); - } - else if (item.function != nullptr) { - name = demangle(hipKernelNameRef(item.function)); - } - if (!name.empty()) { - uint32_t string_id = reverseStrings_[name]; - if (string_id == 0) { - string_id = nextStringId_++; - reverseStrings_[name] = string_id; - strings_[string_id] = name; - } - kernelNames_[item.id] = string_id; - } - - logger.handleGenericActivity(a); - ++count; - } - - // Async Ops - - for (auto& buffer : *gpuTraceBuffers_) { - const roctracer_record_t* record = (const roctracer_record_t*)(buffer.data); - const roctracer_record_t* end_record = (const roctracer_record_t*)(buffer.data + buffer.validSize); - GenericTraceActivity a; - - while (record < end_record) { - if ((record->domain == ACTIVITY_DOMAIN_HIP_API) && (loggedIds_.contains(record->op))) { - const char *name = roctracer_op_string(record->domain, record->op, record->kind); - a.device = record->process_id; - a.resource = record->thread_id; - - a.startTime = (record->begin_ns + toffset) / 1000; - a.endTime = (record->end_ns + toffset) / 1000; - a.id = record->correlation_id; - - a.activityType = ActivityType::CUDA_RUNTIME; - a.activityName = std::string(name); - a.flow.id = record->correlation_id; - a.flow.type = kLinkAsyncCpuGpu; - a.flow.start = true; - - logger.handleGenericActivity(a); - ++count; - } - else if (record->domain == ACTIVITY_DOMAIN_HCC_OPS) { - // Overlay launch metadata for kernels - auto kit = kernelLaunches_.find(record->correlation_id); - if (kit != kernelLaunches_.end()) { - a = (*kit).second; - } - - const char *name = roctracer_op_string(record->domain, record->op, record->kind); - a.device = record->device_id; - a.resource = record->queue_id; - - a.startTime = (record->begin_ns + toffset) / 1000; - a.endTime = (record->end_ns + toffset) / 1000; - a.id = record->correlation_id; - - a.activityType = ActivityType::CONCURRENT_KERNEL; - a.activityName = std::string(name); - a.flow.id = record->correlation_id; - a.flow.type = kLinkAsyncCpuGpu; - - auto it = kernelNames_.find(record->correlation_id); - if (it != kernelNames_.end()) { - a.activityName = strings_[it->second]; - } - - logger.handleGenericActivity(a); - ++count; - } - - roctracer_next_record(record, &record); - } - } - return count; -} - -void RoctracerActivityApi::clearActivities() { - gpuTraceBuffers_->clear(); - rows_.clear(); - kernelRows_.clear(); - copyRows_.clear(); - mallocRows_.clear(); - kernelLaunches_.clear(); -} - -void RoctracerActivityApi::api_callback(uint32_t domain, uint32_t cid, const void* callback_data, void* arg) -{ - RoctracerActivityApi *dis = &singleton(); - - if (domain == ACTIVITY_DOMAIN_HIP_API && dis->loggedIds_.contains(cid)) { - const hip_api_data_t* data = (const hip_api_data_t*)(callback_data); - - // Pack callbacks into row structures - - static timespec timestamp; // FIXME verify thread safety - - if (data->phase == ACTIVITY_API_PHASE_ENTER) { - clock_gettime(CLOCK_MONOTONIC, ×tamp); // record proper clock - } - else { // (data->phase == ACTIVITY_API_PHASE_EXIT) - timespec endTime; - timespec startTime { timestamp }; - clock_gettime(CLOCK_MONOTONIC, &endTime); // record proper clock - - switch (cid) { - case HIP_API_ID_hipLaunchKernel: - case HIP_API_ID_hipExtLaunchKernel: - case HIP_API_ID_hipLaunchCooperativeKernel: // Should work here - { - auto &args = data->args.hipLaunchKernel; - dis->kernelRows_.emplace_back(data->correlation_id, - domain, - cid, - processId(), - systemThreadId(), - timespec_to_ns(startTime), - timespec_to_ns(endTime), - args.function_address, - nullptr, - args.numBlocks.x, - args.numBlocks.y, - args.numBlocks.z, - args.dimBlocks.x, - args.dimBlocks.y, - args.dimBlocks.z, - args.sharedMemBytes, - args.stream - ); - } - break; - case HIP_API_ID_hipHccModuleLaunchKernel: - case HIP_API_ID_hipModuleLaunchKernel: - case HIP_API_ID_hipExtModuleLaunchKernel: - { - auto &args = data->args.hipModuleLaunchKernel; - dis->kernelRows_.emplace_back(data->correlation_id, - domain, - cid, - processId(), - systemThreadId(), - timespec_to_ns(startTime), - timespec_to_ns(endTime), - nullptr, - args.f, - args.gridDimX, - args.gridDimY, - args.gridDimZ, - args.blockDimX, - args.blockDimY, - args.blockDimZ, - args.sharedMemBytes, - args.stream - ); - } - break; - case HIP_API_ID_hipLaunchCooperativeKernelMultiDevice: - case HIP_API_ID_hipExtLaunchMultiKernelMultiDevice: -#if 0 - { - auto &args = data->args.hipLaunchCooperativeKernelMultiDevice.launchParamsList__val; - dis->kernelRows_.emplace_back(data->correlation_id, - domain, - cid, - processId(), - systemThreadId(), - timespec_to_ns(startTime), - timespec_to_ns(endTime), - args.function_address, - nullptr, - args.numBlocks.x, - args.numBlocks.y, - args.numBlocks.z, - args.dimBlocks.x, - args.dimBlocks.y, - args.dimBlocks.z, - args.sharedMemBytes, - args.stream - ); - } -#endif - break; - case HIP_API_ID_hipMalloc: - dis->mallocRows_.emplace_back(data->correlation_id, - domain, - cid, - processId(), - systemThreadId(), - timespec_to_ns(startTime), - timespec_to_ns(endTime), - data->args.hipMalloc.ptr__val, - data->args.hipMalloc.size - ); - break; - case HIP_API_ID_hipFree: - dis->mallocRows_.emplace_back(data->correlation_id, - domain, - cid, - processId(), - systemThreadId(), - timespec_to_ns(startTime), - timespec_to_ns(endTime), - data->args.hipFree.ptr, - 0 - ); - break; - case HIP_API_ID_hipMemcpy: - { - auto &args = data->args.hipMemcpy; - dis->copyRows_.emplace_back(data->correlation_id, - domain, - cid, - processId(), - systemThreadId(), - timespec_to_ns(startTime), - timespec_to_ns(endTime), - args.src, - args.dst, - args.sizeBytes, - args.kind, - static_cast(0) // use placeholder? - ); - } - break; - case HIP_API_ID_hipMemcpyAsync: - case HIP_API_ID_hipMemcpyWithStream: - { - auto &args = data->args.hipMemcpyAsync; - dis->copyRows_.emplace_back(data->correlation_id, - domain, - cid, - processId(), - systemThreadId(), - timespec_to_ns(startTime), - timespec_to_ns(endTime), - args.src, - args.dst, - args.sizeBytes, - args.kind, - args.stream - ); - } - break; - default: - dis->rows_.emplace_back(data->correlation_id, - domain, - cid, - processId(), - systemThreadId(), - timespec_to_ns(startTime), - timespec_to_ns(endTime) - ); - break; - } - } - } -} - -void RoctracerActivityApi::activity_callback(const char* begin, const char* end, void* arg) -{ - size_t size = end - begin; - uint8_t *buffer = (uint8_t*) malloc(size); - auto &gpuTraceBuffers = singleton().gpuTraceBuffers_; - memcpy(buffer, begin, size); - gpuTraceBuffers->emplace_back(buffer, size); -} - -void RoctracerActivityApi::enableActivities( - const std::set& selected_activities) { -#ifdef HAS_ROCTRACER - if (!registered_) { - roctracer_set_properties(ACTIVITY_DOMAIN_HIP_API, nullptr); // Magic encantation - - // Set some api calls to ignore - loggedIds_.setInvertMode(true); // Omit the specified api - loggedIds_.add("hipGetDevice"); - loggedIds_.add("hipSetDevice"); - loggedIds_.add("hipGetLastError"); - loggedIds_.add("__hipPushCallConfiguration"); - loggedIds_.add("__hipPopCallConfiguration"); - loggedIds_.add("hipCtxSetCurrent"); - loggedIds_.add("hipEventRecord"); - loggedIds_.add("hipEventQuery"); - loggedIds_.add("hipGetDeviceProperties"); - loggedIds_.add("hipPeekAtLastError"); - loggedIds_.add("hipModuleGetFunction"); - loggedIds_.add("hipEventCreateWithFlags"); - - // Enable API callbacks - if (loggedIds_.invertMode() == true) { - // exclusion list - enable entire domain and turn off things in list - roctracer_enable_domain_callback(ACTIVITY_DOMAIN_HIP_API, api_callback, nullptr); - const std::unordered_map &filter = loggedIds_.filterList(); - for (auto it = filter.begin(); it != filter.end(); ++it) { - roctracer_disable_op_callback(ACTIVITY_DOMAIN_HIP_API, it->first); - } - } - else { - // inclusion list - only enable things in the list - const std::unordered_map &filter = loggedIds_.filterList(); - roctracer_disable_domain_callback(ACTIVITY_DOMAIN_HIP_API); - for (auto it = filter.begin(); it != filter.end(); ++it) { - roctracer_enable_op_callback(ACTIVITY_DOMAIN_HIP_API, it->first, api_callback, nullptr); - } - } - //roctracer_enable_domain_callback(ACTIVITY_DOMAIN_ROCTX, api_callback, nullptr); - - // Allocate default tracing pool - roctracer_properties_t properties; - memset(&properties, 0, sizeof(roctracer_properties_t)); - properties.buffer_size = 0x1000; - roctracer_open_pool(&properties); - - // Enable async op collection - roctracer_properties_t hcc_cb_properties; - memset(&hcc_cb_properties, 0, sizeof(roctracer_properties_t)); - hcc_cb_properties.buffer_size = 0x4000; - hcc_cb_properties.buffer_callback_fun = activity_callback; - roctracer_open_pool_expl(&hcc_cb_properties, &hccPool_); - roctracer_enable_domain_activity_expl(ACTIVITY_DOMAIN_HCC_OPS, hccPool_); - - registered_ = true; - } - - for (const auto& activity : selected_activities) { - if (activity == ActivityType::EXTERNAL_CORRELATION) { - externalCorrelationEnabled_ = true; - } - } - - roctracer_start(); -#endif -} - -void RoctracerActivityApi::disableActivities( - const std::set& selected_activities) { -#ifdef HAS_ROCTRACER - roctracer_stop(); - roctracer_flush_activity_expl(hccPool_); - - for (const auto& activity : selected_activities) { - if (activity == ActivityType::EXTERNAL_CORRELATION) { - externalCorrelationEnabled_ = false; - } - } -#endif -} - -void RoctracerActivityApi::endTracing() { - if (registered_ == true) { - roctracer_disable_domain_callback(ACTIVITY_DOMAIN_HIP_API); - //roctracer_disable_domain_callback(ACTIVITY_DOMAIN_ROCTX); - - roctracer_disable_domain_activity(ACTIVITY_DOMAIN_HCC_OPS); - roctracer_close_pool_expl(hccPool_); - } -} - - -ApiIdList::ApiIdList() -: invert_(true) -{ -} - -void ApiIdList::add(std::string apiName) -{ - uint32_t cid = 0; - if (roctracer_op_code(ACTIVITY_DOMAIN_HIP_API, apiName.c_str(), &cid, nullptr) == ROCTRACER_STATUS_SUCCESS) { - filter_[cid] = 1; - } -} -void ApiIdList::remove(std::string apiName) -{ - uint32_t cid = 0; - if (roctracer_op_code(ACTIVITY_DOMAIN_HIP_API, apiName.c_str(), &cid, nullptr) == ROCTRACER_STATUS_SUCCESS) { - filter_.erase(cid); - } -} - -bool ApiIdList::loadUserPrefs() -{ - // placeholder - return false; -} -bool ApiIdList::contains(uint32_t apiId) -{ - return (filter_.find(apiId) != filter_.end()) ? !invert_ : invert_; // XOR -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.h b/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.h deleted file mode 100644 index 28280253e7c8426e85c11d679785bcd74fa2a0c7..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.h +++ /dev/null @@ -1,171 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef HAS_ROCTRACER -#include -#include -#include -#include -#include -#endif - -#include "ActivityType.h" -#include "GenericTraceActivity.h" -#include "RoctracerActivityBuffer.h" - - -namespace KINETO_NAMESPACE { - -using namespace libkineto; - -class ApiIdList -{ -public: - ApiIdList(); - bool invertMode() { return invert_; } - void setInvertMode(bool invert) { invert_ = invert; } - void add(std::string apiName); - void remove(std::string apiName); - bool loadUserPrefs(); - bool contains(uint32_t apiId); - const std::unordered_map &filterList() { return filter_; } - -private: - std::unordered_map filter_; - bool invert_; -}; - -struct roctracerRow { - roctracerRow(uint64_t id, uint32_t domain, uint32_t cid, uint32_t pid - , uint32_t tid, uint64_t begin, uint64_t end) - : id(id), domain(domain), cid(cid), pid(pid), tid(tid), begin(begin), end(end) {} - uint64_t id; // correlation_id - uint32_t domain; - uint32_t cid; - uint32_t pid; - uint32_t tid; - uint64_t begin; - uint64_t end; -}; - -struct kernelRow : public roctracerRow { - kernelRow(uint64_t id, uint32_t domain, uint32_t cid, uint32_t pid - , uint32_t tid, uint64_t begin, uint64_t end - , const void *faddr, hipFunction_t function - , unsigned int gx, unsigned int gy, unsigned int gz - , unsigned int wx, unsigned int wy, unsigned int wz - , size_t gss, hipStream_t stream) - : roctracerRow(id, domain, cid, pid, tid, begin, end), functionAddr(faddr) - , function(function), gridX(gx), gridY(gy), gridZ(gz) - , workgroupX(wx), workgroupY(wy), workgroupZ(wz), groupSegmentSize(gss) - , stream(stream) {} - const void* functionAddr; - hipFunction_t function; - unsigned int gridX; - unsigned int gridY; - unsigned int gridZ; - unsigned int workgroupX; - unsigned int workgroupY; - unsigned int workgroupZ; - size_t groupSegmentSize; - hipStream_t stream; -}; - -struct copyRow : public roctracerRow { - copyRow(uint64_t id, uint32_t domain, uint32_t cid, uint32_t pid - , uint32_t tid, uint64_t begin, uint64_t end - , const void* src, const void *dst, size_t size, hipMemcpyKind kind - , hipStream_t stream) - : roctracerRow(id, domain, cid, pid, tid, begin, end) - , src(src), dst(dst), size(size), kind(kind), stream(stream) {} - const void *src; - const void *dst; - size_t size; - hipMemcpyKind kind; - hipStream_t stream; -}; - -struct mallocRow : public roctracerRow { - mallocRow(uint64_t id, uint32_t domain, uint32_t cid, uint32_t pid - , uint32_t tid, uint64_t begin, uint64_t end - , const void* ptr, size_t size) - : roctracerRow(id, domain, cid, pid, tid, begin, end) - , ptr(ptr), size(size) {} - const void *ptr; - size_t size; -}; - - -class RoctracerActivityApi { - public: - enum CorrelationFlowType { - Default, - User - }; - - RoctracerActivityApi(); - RoctracerActivityApi(const RoctracerActivityApi&) = delete; - RoctracerActivityApi& operator=(const RoctracerActivityApi&) = delete; - - virtual ~RoctracerActivityApi(); - - static RoctracerActivityApi& singleton(); - - static void pushCorrelationID(int id, CorrelationFlowType type); - static void popCorrelationID(CorrelationFlowType type); - - void enableActivities( - const std::set& selected_activities); - void disableActivities( - const std::set& selected_activities); - void clearActivities(); - - int processActivities(ActivityLogger& logger); - - void setMaxBufferSize(int size); - - std::atomic_bool stopCollection{false}; - - private: - bool registered_{false}; - void endTracing(); - -#ifdef HAS_ROCTRACER - roctracer_pool_t *hccPool_{NULL}; - static void api_callback(uint32_t domain, uint32_t cid, const void* callback_data, void* arg); - static void activity_callback(const char* begin, const char* end, void* arg); - - //Name cache - uint32_t nextStringId_{2}; - std::map strings_; - std::map reverseStrings_; - std::map kernelNames_; - - ApiIdList loggedIds_; - - // Api callback data - std::deque rows_; - std::deque kernelRows_; - std::deque copyRows_; - std::deque mallocRows_; - std::map kernelLaunches_; -#endif - - int maxGpuBufferCount_{0}; - std::unique_ptr> gpuTraceBuffers_; - bool externalCorrelationEnabled_{true}; -}; - -} // namespace KINETO_NAMESPACE - diff --git a/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityBuffer.h b/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityBuffer.h deleted file mode 100644 index cd8a5709a841b7c988ab3f2d1f3108d693343584..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/RoctracerActivityBuffer.h +++ /dev/null @@ -1,30 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include - -namespace KINETO_NAMESPACE { - -class RoctracerActivityBuffer { - public: - // data must be allocated using malloc. - // Ownership is transferred to this object. - RoctracerActivityBuffer(uint8_t* data, size_t validSize) - : data(data), validSize(validSize) {} - - ~RoctracerActivityBuffer() { - free(data); - } - - // Allocated by malloc - uint8_t* data{nullptr}; - - // Number of bytes used - size_t validSize; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/SampleListener.h b/plugins/tensorboard-plugins/libkineto/src/SampleListener.h deleted file mode 100644 index bff86ad122a051d4f3dfdbdd329a3b63d93a7c77..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/SampleListener.h +++ /dev/null @@ -1,146 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include - -namespace KINETO_NAMESPACE { - -class Config; - -class SampleValue { - public: - template - explicit SampleValue(T v) { - init(v); - } - - SampleValue(const SampleValue&) = default; - SampleValue& operator=(const SampleValue&) = delete; - SampleValue(SampleValue&&) = default; - SampleValue& operator=(SampleValue&&) = default; - - bool isInt() const { - return type_ == INT64; - } - - int64_t getInt() const { - assert(isInt()); - return int_; - } - - bool isDouble() const { - return type_ == DOUBLE; - } - - double getDouble() const { - assert(isDouble()); - return dbl_; - } - - inline void operator*=(double x) { - assert(isDouble() || isInt()); - if (isDouble()) { - dbl_ *= x; - } else { - int_ = std::round(int_ * x); - } - } - - inline bool operator<(const SampleValue& o) const { - if (type_ != o.type_) { - return type_ < o.type_; - } else if (type_ == INT64) { - return int_ < o.int_; - } else if (type_ == DOUBLE) { - return dbl_ < o.dbl_; - } - assert(false); - return true; - } - - void print(std::ostream& s) const { - if (type_ == INT64) { - s << int_; - } else if (type_ == DOUBLE) { - s << dbl_; - } else { - assert(false); - } - } - - private: - enum Type { INT64, DOUBLE }; - - template - void init(T v); - - Type type_{INT64}; - union { - int64_t int_{0}; - double dbl_; - }; -}; - -template <> -inline void SampleValue::init(uint64_t v) { - int_ = v, type_ = INT64; -} -template <> -inline void SampleValue::init(int64_t v) { - int_ = v, type_ = INT64; -} -template <> -inline void SampleValue::init(int v) { - int_ = v, type_ = INT64; -} -template <> -inline void SampleValue::init(double v) { - dbl_ = v, type_ = DOUBLE; -} - -inline std::ostream& operator<<(std::ostream& out, const SampleValue& s) { - s.print(out); - return out; -} - -using PercentileList = std::vector>; - -struct Stat { - const std::string& name; - const PercentileList percentileValues; - SampleValue total; -}; - -struct Sample { - Sample(int stats_count) { - stats.reserve(stats_count); - } - - // Offset in milliseconds from first sample in report - int deltaMsec; - std::vector stats; -}; - -// Inherit from this to be notified of samples -class SampleListener { - public: - SampleListener(const SampleListener&) = delete; - SampleListener& operator=(const SampleListener&) = delete; - - virtual ~SampleListener(){}; - - // Report bucketed & aggregated values for event - virtual void handleSample(int device, const Sample& sample, bool from_new_version) = 0; - - virtual void update(const Config& config) = 0; - - protected: - SampleListener() = default; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/ScopeExit.h b/plugins/tensorboard-plugins/libkineto/src/ScopeExit.h deleted file mode 100644 index b9a6bc83ef942c7fb0e4b198b0396e5d75aa5a3a..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ScopeExit.h +++ /dev/null @@ -1,29 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -// Implement a simple scope handler allowing a function to release -// resources when an error or exception occurs - -template -class ScopeExit { - public: - explicit ScopeExit(T t) : t(t) {} - ~ScopeExit() { - t(); - } - T t; -}; - -template -ScopeExit makeScopeExit(T t) { - return ScopeExit(t); -}; - -// Add a level of indirection so __LINE__ is expanded -#define __kINETO_CONCAT(name, line) name##line -#define ANON_VAR(name, line) __kINETO_CONCAT(name, line) - -#define SCOPE_EXIT(func) \ - const auto ANON_VAR(SCOPE_BLOCK, __LINE__) = \ - makeScopeExit([=]() { func; }) diff --git a/plugins/tensorboard-plugins/libkineto/src/ThreadUtil.cpp b/plugins/tensorboard-plugins/libkineto/src/ThreadUtil.cpp deleted file mode 100644 index 0f67d54d58512aa47b05aed69748a6894aa06b1c..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/ThreadUtil.cpp +++ /dev/null @@ -1,203 +0,0 @@ -#include "ThreadUtil.h" - -#ifndef _MSC_VER -#include -#include -#include -#include -#else // _MSC_VER -#include -#include -#define WIN32_LEAN_AND_MEAN -#define NOGDI -#include -#include -#undef ERROR -#endif // _MSC_VER - -#ifdef __ANDROID__ -#include -#endif - -#include -#include -#include - -namespace libkineto { - -namespace { -thread_local int32_t _pid = 0; -thread_local int32_t _tid = 0; -thread_local int32_t _sysTid = 0; -} - -int32_t processId() { - if (!_pid) { -#ifndef _MSC_VER - _pid = (int32_t)getpid(); -#else - _pid = (int32_t)GetCurrentProcessId(); -#endif - } - return _pid; -} - -int32_t systemThreadId() { - if (!_sysTid) { -#ifdef __APPLE__ - _sysTid = (int32_t)syscall(SYS_thread_selfid); -#elif defined _MSC_VER - _sysTid = (int32_t)GetCurrentThreadId(); -#else - _sysTid = (int32_t)syscall(SYS_gettid); -#endif - } - return _sysTid; -} - -int32_t threadId() { - if (!_tid) { -#ifdef __APPLE__ - uint64_t tid; - pthread_threadid_np(nullptr, &tid); - _tid = tid; -#elif defined _MSC_VER - _tid = (int32_t)GetCurrentThreadId(); -#else - pthread_t pth = pthread_self(); - int32_t* ptr = reinterpret_cast(&pth); - _tid = *ptr; -#endif - } - return _tid; -} - -namespace { -static constexpr size_t kMaxThreadNameLength = 16; - -static constexpr const char* basename(const char* s, int off = 0) { - return !s[off] - ? s - : s[off] == '/' ? basename(&s[off + 1]) : basename(s, off + 1); -} -#if defined(_MSC_VER) -void *getKernel32Func(const char* procName) { - return GetProcAddress(GetModuleHandleA("KERNEL32.DLL"), procName); -} -#endif -} - -bool setThreadName(const std::string& name) { -#ifdef __APPLE__ - return 0 == pthread_setname_np(name.c_str()); -#elif defined _MSC_VER - // Per https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-setthreaddescription - // Use runtime linking to set thread description - static auto _SetThreadDescription = reinterpret_cast(getKernel32Func("SetThreadDescription")); - if (!_SetThreadDescription) { - return false; - } - std::wstring_convert> conv; - std::wstring wname = conv.from_bytes(name); - HRESULT hr = _SetThreadDescription(GetCurrentThread(), wname.c_str()); - return SUCCEEDED(hr); -#else - return 0 == pthread_setname_np(pthread_self(), name.c_str()); -#endif -} - -std::string getThreadName() { -#ifndef _MSC_VER - char buf[kMaxThreadNameLength] = ""; - if ( -#ifndef __ANDROID__ - pthread_getname_np(pthread_self(), buf, kMaxThreadNameLength) != 0 -#else - prctl(PR_GET_NAME, buf, kMaxThreadNameLength) != 0 -#endif - ) { - return "Unknown"; - } - return buf; -#else // _MSC_VER - static auto _GetThreadDescription = reinterpret_cast(getKernel32Func("GetThreadDescription")); - if (!_GetThreadDescription) { - return "Unknown"; - } - PWSTR data; - HRESULT hr = _GetThreadDescription(GetCurrentThread(), &data); - if (!SUCCEEDED(hr)) { - return ""; - } - std::wstring_convert> conv; - std::string name = conv.to_bytes(data); - LocalFree(data); - return name; -#endif -} - -// Linux: -// Extract process name from /proc/pid/cmdline. This does not have -// the 16 character limit that /proc/pid/status and /prod/pid/comm has. -std::string processName(int32_t pid) { -#ifdef __linux__ - FILE* cmdfile = fopen(fmt::format("/proc/{}/cmdline", pid).c_str(), "r"); - if (cmdfile != nullptr) { - char* command = nullptr; - int scanned = fscanf(cmdfile, "%ms", &command); - fclose(cmdfile); - if (scanned > 0 && command) { - std::string ret(basename(command)); - free(command); - return ret; - } - } - std::cerr << "Failed to read process name for pid " << pid << std::endl; -#endif - return ""; -} - -// Max number of parent pids to collect, just for extra safeguarding. -constexpr int kMaxParentPids = 10; - -// Return a pair of -static std::pair parentPidAndCommand(int32_t pid) { -#ifdef __linux__ - FILE* statfile = fopen(fmt::format("/proc/{}/stat", pid).c_str(), "r"); - if (statfile == nullptr) { - return std::make_pair(0, ""); - } - int32_t parent_pid; - char* command = nullptr; - int scanned = fscanf(statfile, "%*d (%m[^)]) %*c %d", &command, &parent_pid); - fclose(statfile); - std::pair ret; - if (scanned == 2) { - ret = std::make_pair(parent_pid, std::string(command)); - } else { - std::cerr << "Failed to parse /proc/" << pid << "/stat" << std::endl; - ret = std::make_pair(0, ""); - } - - // The 'm' character in the format tells fscanf to allocate memory - // for the parsed string, which we need to free here. - free(command); - return ret; -#else - return std::make_pair(0, ""); -#endif -} - -std::vector> pidCommandPairsOfAncestors() { - std::vector> pairs; - pairs.reserve(kMaxParentPids + 1); - int32_t curr_pid = processId(); - for (int i = 0; i <= kMaxParentPids && curr_pid > 1; i++) { - std::pair ppid_and_comm = parentPidAndCommand(curr_pid); - pairs.push_back(std::make_pair(curr_pid, ppid_and_comm.second)); - curr_pid = ppid_and_comm.first; - } - return pairs; -} - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/WeakSymbols.cpp b/plugins/tensorboard-plugins/libkineto/src/WeakSymbols.cpp deleted file mode 100644 index 540a5ac8f97c8f38c7ee3d31ea285a3ab7c9f375..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/WeakSymbols.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include - -#ifndef _MSC_VER -extern "C" { -// This function is needed to avoid superfluous dependency on GNU OpenMP library when cuPTI is linked statically -// For more details see https://github.com/pytorch/pytorch/issues/51026 -__attribute__((weak)) int acc_get_device_type() { - throw std::runtime_error("Dummy implementation of acc_get_device_type is not supposed to be called!"); -} - -} // extern "C" -#endif diff --git a/plugins/tensorboard-plugins/libkineto/src/cupti_call.h b/plugins/tensorboard-plugins/libkineto/src/cupti_call.h deleted file mode 100644 index fd6ebae7691ed607867db5717248ba22f4efa5c0..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/cupti_call.h +++ /dev/null @@ -1,33 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include - -#ifdef HAS_CUPTI - -#include - -#define CUPTI_CALL(call) \ - [&]() -> CUptiResult { \ - CUptiResult _status_ = call; \ - if (_status_ != CUPTI_SUCCESS) { \ - const char* _errstr_ = nullptr; \ - cuptiGetResultString(_status_, &_errstr_); \ - LOG(WARNING) << fmt::format( \ - "function {} failed with error {} ({})", \ - #call, \ - _errstr_, \ - (int)_status_); \ - } \ - return _status_; \ - }() - -#define CUPTI_CALL_NOWARN(call) call - -#else - -#define CUPTI_CALL(call) call -#define CUPTI_CALL_NOWARN(call) call - -#endif // HAS_CUPTI diff --git a/plugins/tensorboard-plugins/libkineto/src/cupti_strings.cpp b/plugins/tensorboard-plugins/libkineto/src/cupti_strings.cpp deleted file mode 100644 index 4535273a277e04b0b6f98b539df82955ef62468f..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/cupti_strings.cpp +++ /dev/null @@ -1,502 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "cupti_strings.h" - -namespace libkineto { - -const char* memcpyKindString( - CUpti_ActivityMemcpyKind kind) { - switch (kind) { - case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD: - return "HtoD"; - case CUPTI_ACTIVITY_MEMCPY_KIND_DTOH: - return "DtoH"; - case CUPTI_ACTIVITY_MEMCPY_KIND_HTOA: - return "HtoA"; - case CUPTI_ACTIVITY_MEMCPY_KIND_ATOH: - return "AtoH"; - case CUPTI_ACTIVITY_MEMCPY_KIND_ATOA: - return "AtoA"; - case CUPTI_ACTIVITY_MEMCPY_KIND_ATOD: - return "AtoD"; - case CUPTI_ACTIVITY_MEMCPY_KIND_DTOA: - return "DtoA"; - case CUPTI_ACTIVITY_MEMCPY_KIND_DTOD: - return "DtoD"; - case CUPTI_ACTIVITY_MEMCPY_KIND_HTOH: - return "HtoH"; - case CUPTI_ACTIVITY_MEMCPY_KIND_PTOP: - return "PtoP"; - default: - break; - } - return ""; -} - -const char* memoryKindString( - CUpti_ActivityMemoryKind kind) { - switch (kind) { - case CUPTI_ACTIVITY_MEMORY_KIND_UNKNOWN: - return "Unknown"; - case CUPTI_ACTIVITY_MEMORY_KIND_PAGEABLE: - return "Pageable"; - case CUPTI_ACTIVITY_MEMORY_KIND_PINNED: - return "Pinned"; - case CUPTI_ACTIVITY_MEMORY_KIND_DEVICE: - return "Device"; - case CUPTI_ACTIVITY_MEMORY_KIND_ARRAY: - return "Array"; - case CUPTI_ACTIVITY_MEMORY_KIND_MANAGED: - return "Managed"; - case CUPTI_ACTIVITY_MEMORY_KIND_DEVICE_STATIC: - return "Device Static"; - case CUPTI_ACTIVITY_MEMORY_KIND_MANAGED_STATIC: - return "Managed Static"; - case CUPTI_ACTIVITY_MEMORY_KIND_FORCE_INT: - return "Force Int"; - default: - return "Unrecognized"; - } -} - -const char* overheadKindString( - CUpti_ActivityOverheadKind kind) { - switch (kind) { - case CUPTI_ACTIVITY_OVERHEAD_UNKNOWN: - return "Unknown"; - case CUPTI_ACTIVITY_OVERHEAD_DRIVER_COMPILER: - return "Driver Compiler"; - case CUPTI_ACTIVITY_OVERHEAD_CUPTI_BUFFER_FLUSH: - return "Buffer Flush"; - case CUPTI_ACTIVITY_OVERHEAD_CUPTI_INSTRUMENTATION: - return "Instrumentation"; - case CUPTI_ACTIVITY_OVERHEAD_CUPTI_RESOURCE: - return "Resource"; - case CUPTI_ACTIVITY_OVERHEAD_FORCE_INT: - return "Force Int"; - default: - return "Unrecognized"; - } -} - - - -static const char* runtimeCbidNames[] = { - "INVALID", - "cudaDriverGetVersion", - "cudaRuntimeGetVersion", - "cudaGetDeviceCount", - "cudaGetDeviceProperties", - "cudaChooseDevice", - "cudaGetChannelDesc", - "cudaCreateChannelDesc", - "cudaConfigureCall", - "cudaSetupArgument", - "cudaGetLastError", - "cudaPeekAtLastError", - "cudaGetErrorString", - "cudaLaunch", - "cudaFuncSetCacheConfig", - "cudaFuncGetAttributes", - "cudaSetDevice", - "cudaGetDevice", - "cudaSetValidDevices", - "cudaSetDeviceFlags", - "cudaMalloc", - "cudaMallocPitch", - "cudaFree", - "cudaMallocArray", - "cudaFreeArray", - "cudaMallocHost", - "cudaFreeHost", - "cudaHostAlloc", - "cudaHostGetDevicePointer", - "cudaHostGetFlags", - "cudaMemGetInfo", - "cudaMemcpy", - "cudaMemcpy2D", - "cudaMemcpyToArray", - "cudaMemcpy2DToArray", - "cudaMemcpyFromArray", - "cudaMemcpy2DFromArray", - "cudaMemcpyArrayToArray", - "cudaMemcpy2DArrayToArray", - "cudaMemcpyToSymbol", - "cudaMemcpyFromSymbol", - "cudaMemcpyAsync", - "cudaMemcpyToArrayAsync", - "cudaMemcpyFromArrayAsync", - "cudaMemcpy2DAsync", - "cudaMemcpy2DToArrayAsync", - "cudaMemcpy2DFromArrayAsync", - "cudaMemcpyToSymbolAsync", - "cudaMemcpyFromSymbolAsync", - "cudaMemset", - "cudaMemset2D", - "cudaMemsetAsync", - "cudaMemset2DAsync", - "cudaGetSymbolAddress", - "cudaGetSymbolSize", - "cudaBindTexture", - "cudaBindTexture2D", - "cudaBindTextureToArray", - "cudaUnbindTexture", - "cudaGetTextureAlignmentOffset", - "cudaGetTextureReference", - "cudaBindSurfaceToArray", - "cudaGetSurfaceReference", - "cudaGLSetGLDevice", - "cudaGLRegisterBufferObject", - "cudaGLMapBufferObject", - "cudaGLUnmapBufferObject", - "cudaGLUnregisterBufferObject", - "cudaGLSetBufferObjectMapFlags", - "cudaGLMapBufferObjectAsync", - "cudaGLUnmapBufferObjectAsync", - "cudaWGLGetDevice", - "cudaGraphicsGLRegisterImage", - "cudaGraphicsGLRegisterBuffer", - "cudaGraphicsUnregisterResource", - "cudaGraphicsResourceSetMapFlags", - "cudaGraphicsMapResources", - "cudaGraphicsUnmapResources", - "cudaGraphicsResourceGetMappedPointer", - "cudaGraphicsSubResourceGetMappedArray", - "cudaVDPAUGetDevice", - "cudaVDPAUSetVDPAUDevice", - "cudaGraphicsVDPAURegisterVideoSurface", - "cudaGraphicsVDPAURegisterOutputSurface", - "cudaD3D11GetDevice", - "cudaD3D11GetDevices", - "cudaD3D11SetDirect3DDevice", - "cudaGraphicsD3D11RegisterResource", - "cudaD3D10GetDevice", - "cudaD3D10GetDevices", - "cudaD3D10SetDirect3DDevice", - "cudaGraphicsD3D10RegisterResource", - "cudaD3D10RegisterResource", - "cudaD3D10UnregisterResource", - "cudaD3D10MapResources", - "cudaD3D10UnmapResources", - "cudaD3D10ResourceSetMapFlags", - "cudaD3D10ResourceGetSurfaceDimensions", - "cudaD3D10ResourceGetMappedArray", - "cudaD3D10ResourceGetMappedPointer", - "cudaD3D10ResourceGetMappedSize", - "cudaD3D10ResourceGetMappedPitch", - "cudaD3D9GetDevice", - "cudaD3D9GetDevices", - "cudaD3D9SetDirect3DDevice", - "cudaD3D9GetDirect3DDevice", - "cudaGraphicsD3D9RegisterResource", - "cudaD3D9RegisterResource", - "cudaD3D9UnregisterResource", - "cudaD3D9MapResources", - "cudaD3D9UnmapResources", - "cudaD3D9ResourceSetMapFlags", - "cudaD3D9ResourceGetSurfaceDimensions", - "cudaD3D9ResourceGetMappedArray", - "cudaD3D9ResourceGetMappedPointer", - "cudaD3D9ResourceGetMappedSize", - "cudaD3D9ResourceGetMappedPitch", - "cudaD3D9Begin", - "cudaD3D9End", - "cudaD3D9RegisterVertexBuffer", - "cudaD3D9UnregisterVertexBuffer", - "cudaD3D9MapVertexBuffer", - "cudaD3D9UnmapVertexBuffer", - "cudaThreadExit", - "cudaSetDoubleForDevice", - "cudaSetDoubleForHost", - "cudaThreadSynchronize", - "cudaThreadGetLimit", - "cudaThreadSetLimit", - "cudaStreamCreate", - "cudaStreamDestroy", - "cudaStreamSynchronize", - "cudaStreamQuery", - "cudaEventCreate", - "cudaEventCreateWithFlags", - "cudaEventRecord", - "cudaEventDestroy", - "cudaEventSynchronize", - "cudaEventQuery", - "cudaEventElapsedTime", - "cudaMalloc3D", - "cudaMalloc3DArray", - "cudaMemset3D", - "cudaMemset3DAsync", - "cudaMemcpy3D", - "cudaMemcpy3DAsync", - "cudaThreadSetCacheConfig", - "cudaStreamWaitEvent", - "cudaD3D11GetDirect3DDevice", - "cudaD3D10GetDirect3DDevice", - "cudaThreadGetCacheConfig", - "cudaPointerGetAttributes", - "cudaHostRegister", - "cudaHostUnregister", - "cudaDeviceCanAccessPeer", - "cudaDeviceEnablePeerAccess", - "cudaDeviceDisablePeerAccess", - "cudaPeerRegister", - "cudaPeerUnregister", - "cudaPeerGetDevicePointer", - "cudaMemcpyPeer", - "cudaMemcpyPeerAsync", - "cudaMemcpy3DPeer", - "cudaMemcpy3DPeerAsync", - "cudaDeviceReset", - "cudaDeviceSynchronize", - "cudaDeviceGetLimit", - "cudaDeviceSetLimit", - "cudaDeviceGetCacheConfig", - "cudaDeviceSetCacheConfig", - "cudaProfilerInitialize", - "cudaProfilerStart", - "cudaProfilerStop", - "cudaDeviceGetByPCIBusId", - "cudaDeviceGetPCIBusId", - "cudaGLGetDevices", - "cudaIpcGetEventHandle", - "cudaIpcOpenEventHandle", - "cudaIpcGetMemHandle", - "cudaIpcOpenMemHandle", - "cudaIpcCloseMemHandle", - "cudaArrayGetInfo", - "cudaFuncSetSharedMemConfig", - "cudaDeviceGetSharedMemConfig", - "cudaDeviceSetSharedMemConfig", - "cudaCreateTextureObject", - "cudaDestroyTextureObject", - "cudaGetTextureObjectResourceDesc", - "cudaGetTextureObjectTextureDesc", - "cudaCreateSurfaceObject", - "cudaDestroySurfaceObject", - "cudaGetSurfaceObjectResourceDesc", - "cudaMallocMipmappedArray", - "cudaGetMipmappedArrayLevel", - "cudaFreeMipmappedArray", - "cudaBindTextureToMipmappedArray", - "cudaGraphicsResourceGetMappedMipmappedArray", - "cudaStreamAddCallback", - "cudaStreamCreateWithFlags", - "cudaGetTextureObjectResourceViewDesc", - "cudaDeviceGetAttribute", - "cudaStreamDestroy", - "cudaStreamCreateWithPriority", - "cudaStreamGetPriority", - "cudaStreamGetFlags", - "cudaDeviceGetStreamPriorityRange", - "cudaMallocManaged", - "cudaOccupancyMaxActiveBlocksPerMultiprocessor", - "cudaStreamAttachMemAsync", - "cudaGetErrorName", - "cudaOccupancyMaxActiveBlocksPerMultiprocessor", - "cudaLaunchKernel", - "cudaGetDeviceFlags", - "cudaLaunch_ptsz", - "cudaLaunchKernel_ptsz", - "cudaMemcpy_ptds", - "cudaMemcpy2D_ptds", - "cudaMemcpyToArray_ptds", - "cudaMemcpy2DToArray_ptds", - "cudaMemcpyFromArray_ptds", - "cudaMemcpy2DFromArray_ptds", - "cudaMemcpyArrayToArray_ptds", - "cudaMemcpy2DArrayToArray_ptds", - "cudaMemcpyToSymbol_ptds", - "cudaMemcpyFromSymbol_ptds", - "cudaMemcpyAsync_ptsz", - "cudaMemcpyToArrayAsync_ptsz", - "cudaMemcpyFromArrayAsync_ptsz", - "cudaMemcpy2DAsync_ptsz", - "cudaMemcpy2DToArrayAsync_ptsz", - "cudaMemcpy2DFromArrayAsync_ptsz", - "cudaMemcpyToSymbolAsync_ptsz", - "cudaMemcpyFromSymbolAsync_ptsz", - "cudaMemset_ptds", - "cudaMemset2D_ptds", - "cudaMemsetAsync_ptsz", - "cudaMemset2DAsync_ptsz", - "cudaStreamGetPriority_ptsz", - "cudaStreamGetFlags_ptsz", - "cudaStreamSynchronize_ptsz", - "cudaStreamQuery_ptsz", - "cudaStreamAttachMemAsync_ptsz", - "cudaEventRecord_ptsz", - "cudaMemset3D_ptds", - "cudaMemset3DAsync_ptsz", - "cudaMemcpy3D_ptds", - "cudaMemcpy3DAsync_ptsz", - "cudaStreamWaitEvent_ptsz", - "cudaStreamAddCallback_ptsz", - "cudaMemcpy3DPeer_ptds", - "cudaMemcpy3DPeerAsync_ptsz", - "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", - "cudaMemPrefetchAsync", - "cudaMemPrefetchAsync_ptsz", - "cudaMemAdvise", - "cudaDeviceGetP2PAttribute", - "cudaGraphicsEGLRegisterImage", - "cudaEGLStreamConsumerConnect", - "cudaEGLStreamConsumerDisconnect", - "cudaEGLStreamConsumerAcquireFrame", - "cudaEGLStreamConsumerReleaseFrame", - "cudaEGLStreamProducerConnect", - "cudaEGLStreamProducerDisconnect", - "cudaEGLStreamProducerPresentFrame", - "cudaEGLStreamProducerReturnFrame", - "cudaGraphicsResourceGetMappedEglFrame", - "cudaMemRangeGetAttribute", - "cudaMemRangeGetAttributes", - "cudaEGLStreamConsumerConnectWithFlags", - "cudaLaunchCooperativeKernel", - "cudaLaunchCooperativeKernel_ptsz", - "cudaEventCreateFromEGLSync", - "cudaLaunchCooperativeKernelMultiDevice", - "cudaFuncSetAttribute", - "cudaImportExternalMemory", - "cudaExternalMemoryGetMappedBuffer", - "cudaExternalMemoryGetMappedMipmappedArray", - "cudaDestroyExternalMemory", - "cudaImportExternalSemaphore", - "cudaSignalExternalSemaphoresAsync", - "cudaSignalExternalSemaphoresAsync_ptsz", - "cudaWaitExternalSemaphoresAsync", - "cudaWaitExternalSemaphoresAsync_ptsz", - "cudaDestroyExternalSemaphore", - "cudaLaunchHostFunc", - "cudaLaunchHostFunc_ptsz", - "cudaGraphCreate", - "cudaGraphKernelNodeGetParams", - "cudaGraphKernelNodeSetParams", - "cudaGraphAddKernelNode", - "cudaGraphAddMemcpyNode", - "cudaGraphMemcpyNodeGetParams", - "cudaGraphMemcpyNodeSetParams", - "cudaGraphAddMemsetNode", - "cudaGraphMemsetNodeGetParams", - "cudaGraphMemsetNodeSetParams", - "cudaGraphAddHostNode", - "cudaGraphHostNodeGetParams", - "cudaGraphAddChildGraphNode", - "cudaGraphChildGraphNodeGetGraph", - "cudaGraphAddEmptyNode", - "cudaGraphClone", - "cudaGraphNodeFindInClone", - "cudaGraphNodeGetType", - "cudaGraphGetRootNodes", - "cudaGraphNodeGetDependencies", - "cudaGraphNodeGetDependentNodes", - "cudaGraphAddDependencies", - "cudaGraphRemoveDependencies", - "cudaGraphDestroyNode", - "cudaGraphInstantiate", - "cudaGraphLaunch", - "cudaGraphLaunch_ptsz", - "cudaGraphExecDestroy", - "cudaGraphDestroy", - "cudaStreamBeginCapture", - "cudaStreamBeginCapture_ptsz", - "cudaStreamIsCapturing", - "cudaStreamIsCapturing_ptsz", - "cudaStreamEndCapture", - "cudaStreamEndCapture_ptsz", - "cudaGraphHostNodeSetParams", - "cudaGraphGetNodes", - "cudaGraphGetEdges", - "cudaStreamGetCaptureInfo", - "cudaStreamGetCaptureInfo_ptsz", - "cudaGraphExecKernelNodeSetParams", - "cudaThreadExchangeStreamCaptureMode", - "cudaDeviceGetNvSciSyncAttributes", - "cudaOccupancyAvailableDynamicSMemPerBlock", - "cudaStreamSetFlags", - "cudaStreamSetFlags_ptsz", - "cudaGraphExecMemcpyNodeSetParams", - "cudaGraphExecMemsetNodeSetParams", - "cudaGraphExecHostNodeSetParams", - "cudaGraphExecUpdate", - "cudaGetFuncBySymbol", - "cudaCtxResetPersistingL2Cache", - "cudaGraphKernelNodeCopyAttributes", - "cudaGraphKernelNodeGetAttribute", - "cudaGraphKernelNodeSetAttribute", - "cudaStreamCopyAttributes", - "cudaStreamCopyAttributes_ptsz", - "cudaStreamGetAttribute", - "cudaStreamGetAttribute_ptsz", - "cudaStreamSetAttribute", - "cudaStreamSetAttribute_ptsz", - "cudaDeviceGetTexture1DLinearMaxWidth", - "cudaGraphUpload", - "cudaGraphUpload_ptsz", - "cudaGraphAddMemcpyNodeToSymbol", - "cudaGraphAddMemcpyNodeFromSymbol", - "cudaGraphAddMemcpyNode1D", - "cudaGraphMemcpyNodeSetParamsToSymbol", - "cudaGraphMemcpyNodeSetParamsFromSymbol", - "cudaGraphMemcpyNodeSetParams1D", - "cudaGraphExecMemcpyNodeSetParamsToSymbol", - "cudaGraphExecMemcpyNodeSetParamsFromSymbol", - "cudaGraphExecMemcpyNodeSetParams1D", - "cudaArrayGetSparseProperties", - "cudaMipmappedArrayGetSparseProperties", - "cudaGraphExecChildGraphNodeSetParams", - "cudaGraphAddEventRecordNode", - "cudaGraphEventRecordNodeGetEvent", - "cudaGraphEventRecordNodeSetEvent", - "cudaGraphAddEventWaitNode", - "cudaGraphEventWaitNodeGetEvent", - "cudaGraphEventWaitNodeSetEvent", - "cudaGraphExecEventRecordNodeSetEvent", - "cudaGraphExecEventWaitNodeSetEvent", - "cudaEventRecordWithFlags", - "cudaEventRecordWithFlags_ptsz", - "cudaDeviceGetDefaultMemPool", - "cudaMallocAsync", - "cudaMallocAsync_ptsz", - "cudaFreeAsync", - "cudaFreeAsync_ptsz", - "cudaMemPoolTrimTo", - "cudaMemPoolSetAttribute", - "cudaMemPoolGetAttribute", - "cudaMemPoolSetAccess", - "cudaArrayGetPlane", - "cudaMemPoolGetAccess", - "cudaMemPoolCreate", - "cudaMemPoolDestroy", - "cudaDeviceSetMemPool", - "cudaDeviceGetMemPool", - "cudaMemPoolExportToShareableHandle", - "cudaMemPoolImportFromShareableHandle", - "cudaMemPoolExportPointer", - "cudaMemPoolImportPointer", - "cudaMallocFromPoolAsync", - "cudaMallocFromPoolAsync_ptsz", - "cudaSignalExternalSemaphoresAsync", - "cudaSignalExternalSemaphoresAsync", - "cudaWaitExternalSemaphoresAsync", - "cudaWaitExternalSemaphoresAsync", - "cudaGraphAddExternalSemaphoresSignalNode", - "cudaGraphExternalSemaphoresSignalNodeGetParams", - "cudaGraphExternalSemaphoresSignalNodeSetParams", - "cudaGraphAddExternalSemaphoresWaitNode", - "cudaGraphExternalSemaphoresWaitNodeGetParams", - "cudaGraphExternalSemaphoresWaitNodeSetParams", - "cudaGraphExecExternalSemaphoresSignalNodeSetParams", - "cudaGraphExecExternalSemaphoresWaitNodeSetParams", - "SIZE" -}; - -const char* runtimeCbidName(CUpti_CallbackId cbid) { - constexpr int names_size = - sizeof(runtimeCbidNames) / sizeof(runtimeCbidNames[0]); - if (cbid < 0 || cbid >= names_size) { - return runtimeCbidNames[CUPTI_RUNTIME_TRACE_CBID_INVALID]; - } - return runtimeCbidNames[cbid]; -} - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/cupti_strings.h b/plugins/tensorboard-plugins/libkineto/src/cupti_strings.h deleted file mode 100644 index bbfebb983648005d8268d9a29d613d369d6a5384..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/cupti_strings.h +++ /dev/null @@ -1,14 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include - -namespace libkineto { - -const char* memoryKindString(CUpti_ActivityMemoryKind kind); -const char* memcpyKindString(CUpti_ActivityMemcpyKind kind); -const char* runtimeCbidName(CUpti_CallbackId cbid); -const char* overheadKindString(CUpti_ActivityOverheadKind kind); - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/init.cpp b/plugins/tensorboard-plugins/libkineto/src/init.cpp deleted file mode 100644 index 4e1022485ac5d17b5af1e0676b6a4595a138e1b5..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/init.cpp +++ /dev/null @@ -1,139 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include -#include - -#include "ActivityProfilerProxy.h" -#include "Config.h" -#ifdef HAS_CUPTI -#include "CuptiCallbackApi.h" -#include "CuptiActivityApi.h" -#include "EventProfilerController.h" -#endif -#include "cupti_call.h" -#include "libkineto.h" - -#include "Logger.h" - -namespace KINETO_NAMESPACE { - -#ifdef HAS_CUPTI -static bool initialized = false; -static std::mutex initMutex; - -static void initProfilers( - CUpti_CallbackDomain /*domain*/, - CUpti_CallbackId /*cbid*/, - const CUpti_CallbackData* cbInfo) { - CUpti_ResourceData* d = (CUpti_ResourceData*)cbInfo; - CUcontext ctx = d->context; - - VLOG(0) << "CUDA Context created"; - std::lock_guard lock(initMutex); - - if (!initialized) { - libkineto::api().initProfilerIfRegistered(); - initialized = true; - VLOG(0) << "libkineto profilers activated"; - } - if (getenv("KINETO_DISABLE_EVENT_PROFILER") != nullptr) { - VLOG(0) << "Event profiler disabled via env var"; - } else { - ConfigLoader& config_loader = libkineto::api().configLoader(); - config_loader.initBaseConfig(); - EventProfilerController::start(ctx, config_loader); - } -} - -// Some models suffer from excessive instrumentation code gen -// on dynamic attach which can hang for more than 5+ seconds. -// If the workload was meant to be traced, preload the CUPTI -// to take the performance hit early on. -// https://docs.nvidia.com/cupti/r_main.html#r_overhead -static bool shouldPreloadCuptiInstrumentation() { - return getenv("PRELOAD_CUPTI_INSTRUMENTATION"); -} - -static void stopProfiler( - CUpti_CallbackDomain /*domain*/, - CUpti_CallbackId /*cbid*/, - const CUpti_CallbackData* cbInfo) { - CUpti_ResourceData* d = (CUpti_ResourceData*)cbInfo; - CUcontext ctx = d->context; - - LOG(INFO) << "CUDA Context destroyed"; - std::lock_guard lock(initMutex); - EventProfilerController::stop(ctx); -} -#endif // HAS_CUPTI - -} // namespace KINETO_NAMESPACE - -// Callback interface with CUPTI and library constructors -using namespace KINETO_NAMESPACE; -extern "C" { - -// Return true if no CUPTI errors occurred during init -bool libkineto_init(bool cpuOnly, bool logOnError) { - bool success = true; -#ifdef HAS_CUPTI - if (!cpuOnly) { - // libcupti will be lazily loaded on this call. - // If it is not available (e.g. CUDA is not installed), - // then this call will return an error and we just abort init. - auto& cbapi = CuptiCallbackApi::singleton(); - bool status = false; - - if (cbapi.initSuccess()){ - const CUpti_CallbackDomain domain = CUPTI_CB_DOMAIN_RESOURCE; - status = cbapi.registerCallback( - domain, CuptiCallbackApi::RESOURCE_CONTEXT_CREATED, initProfilers); - status = status && cbapi.registerCallback( - domain, CuptiCallbackApi::RESOURCE_CONTEXT_DESTROYED, stopProfiler); - - if (status) { - status = cbapi.enableCallback( - domain, CuptiCallbackApi::RESOURCE_CONTEXT_CREATED); - status = status && cbapi.enableCallback( - domain, CuptiCallbackApi::RESOURCE_CONTEXT_DESTROYED); - } - } - - if (!cbapi.initSuccess() || !status) { - success = false; - cpuOnly = true; - if (logOnError) { - CUPTI_CALL(cbapi.getCuptiStatus()); - LOG(WARNING) << "CUPTI initialization failed - " - << "CUDA profiler activities will be missing"; - LOG(INFO) << "If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to " - << "https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti"; - } - } - } - - if (shouldPreloadCuptiInstrumentation()) { - CuptiActivityApi::forceLoadCupti(); - } -#endif // HAS_CUPTI - - ConfigLoader& config_loader = libkineto::api().configLoader(); - libkineto::api().registerProfiler( - std::make_unique(cpuOnly, config_loader)); - - return success; -} - -// The cuda driver calls this function if the CUDA_INJECTION64_PATH environment -// variable is set -int InitializeInjection(void) { - LOG(INFO) << "Injection mode: Initializing libkineto"; - libkineto_init(false /*cpuOnly*/, true /*logOnError*/); - return 1; -} - -void suppressLibkinetoLogMessages() { - SET_LOG_SEVERITY_LEVEL(ERROR); -} - -} // extern C diff --git a/plugins/tensorboard-plugins/libkineto/src/libkineto_api.cpp b/plugins/tensorboard-plugins/libkineto/src/libkineto_api.cpp deleted file mode 100644 index 9a622e4f5e5cfd54848cb8c6dc05b98da2fb6011..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/libkineto_api.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "libkineto.h" - -#include "ConfigLoader.h" -#include "ThreadUtil.h" - -namespace libkineto { - -LibkinetoApi& api() { - static LibkinetoApi instance(ConfigLoader::instance()); - return instance; -} - -void LibkinetoApi::initClientIfRegistered() { - if (client_) { - if (clientRegisterThread_ != threadId()) { - fprintf( - stderr, - "ERROR: External init callback must run in same thread as registerClient " - "(%d != %d)\n", - threadId(), - (int)clientRegisterThread_); - } else { - client_->init(); - } - } -} - -void LibkinetoApi::registerClient(ClientInterface* client) { - client_ = client; - if (client && activityProfiler_) { - // Can initialize straight away - client->init(); - } - // Assume here that the external init callback is *not* threadsafe - // and only call it if it's the same thread that called registerClient - clientRegisterThread_ = threadId(); -} - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/src/output_base.h b/plugins/tensorboard-plugins/libkineto/src/output_base.h deleted file mode 100644 index 29d0d57768c91b8593f202cea51071a1affcd88d..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/output_base.h +++ /dev/null @@ -1,104 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include - -#ifdef HAS_CUPTI -#include -#include "CuptiActivity.h" -#endif // HAS_CUPTI -#include "ActivityBuffers.h" -#include "GenericTraceActivity.h" -#include "ThreadUtil.h" -#include "TraceSpan.h" - -namespace KINETO_NAMESPACE { - class Config; - class GpuKernelActivity; - struct RuntimeActivity; -} - -namespace libkineto { - -using namespace KINETO_NAMESPACE; - -class ActivityLogger { - public: - - virtual ~ActivityLogger() = default; - - struct DeviceInfo { - DeviceInfo(int64_t id, const std::string& name, const std::string& label) : - id(id), name(name), label(label) {} - int64_t id; - const std::string name; - const std::string label; - }; - - struct ResourceInfo { - ResourceInfo( - int64_t deviceId, - int64_t id, - int64_t sortIndex, - const std::string& name) : - id(id), sortIndex(sortIndex), deviceId(deviceId), name(name) {} - int64_t id; - int64_t sortIndex; - int64_t deviceId; - const std::string name; - }; - - struct OverheadInfo { - explicit OverheadInfo(const std::string& name) : name(name) {} - const std::string name; - }; - - virtual void handleDeviceInfo( - const DeviceInfo& info, - uint64_t time) = 0; - - virtual void handleResourceInfo(const ResourceInfo& info, int64_t time) = 0; - - virtual void handleOverheadInfo(const OverheadInfo& info, int64_t time) = 0; - - virtual void handleTraceSpan(const TraceSpan& span) = 0; - - virtual void handleActivity( - const libkineto::ITraceActivity& activity) = 0; - virtual void handleGenericActivity( - const libkineto::GenericTraceActivity& activity) = 0; - -#ifdef HAS_CUPTI - virtual void handleGpuActivity( - const GpuActivity& activity) = 0; - virtual void handleGpuActivity( - const GpuActivity& activity) = 0; - virtual void handleGpuActivity( - const GpuActivity& activity) = 0; - virtual void handleGpuActivity( - const GpuActivity& activity) = 0; -#endif // HAS_CUPTI - - virtual void handleTraceStart( - const std::unordered_map& metadata) = 0; - - void handleTraceStart() { - handleTraceStart(std::unordered_map()); - } - - virtual void finalizeTrace( - const KINETO_NAMESPACE::Config& config, - std::unique_ptr buffers, - int64_t endTime, - std::unordered_map>& metadata) = 0; - - protected: - ActivityLogger() = default; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/output_csv.cpp b/plugins/tensorboard-plugins/libkineto/src/output_csv.cpp deleted file mode 100644 index e56c02293982745ed0c013b83bd04d9f42ea7305..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/output_csv.cpp +++ /dev/null @@ -1,88 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "output_csv.h" - -#include -#include -#include - -#include -#include - -#include "Config.h" -#include "Logger.h" - -namespace KINETO_NAMESPACE { - -static void write_header( - std::ostream& out, - const std::vector& percentiles) { - out << "timestamp,delta_ms,device,event_name"; - for (int p : percentiles) { - out << ",p" << p; - } - out << ",total" << std::endl; -} - -void EventCSVLogger::update(const Config& config) { - eventNames_.clear(); - eventNames_.insert(config.eventNames().begin(), config.eventNames().end()); - eventNames_.insert(config.metricNames().begin(), config.metricNames().end()); - if (config.percentiles() != percentiles_) { - percentiles_ = config.percentiles(); - if (out_) { - write_header(*out_, percentiles_); - } - } -} - -void EventCSVLogger::handleSample(int device, const Sample& sample, bool from_new_version) { - using namespace std::chrono; - if (out_) { - auto now = system_clock::now(); - auto time = system_clock::to_time_t(now); - for (const Stat& s : sample.stats) { - if (eventNames_.find(s.name) == eventNames_.end()) { - continue; - } - *out_ << fmt::format("{:%Y-%m-%d %H:%M:%S}", fmt::localtime(time)) << ","; - *out_ << sample.deltaMsec << ","; - *out_ << device << ","; - *out_ << s.name; - for (const auto& p : s.percentileValues) { - *out_ << "," << p.second; - } - *out_ << "," << s.total << std::endl; - } - } -} - -void EventCSVFileLogger::update(const Config& config) { - if (config.eventLogFile() != filename_) { - if (of_.is_open()) { - of_.close(); - out_ = nullptr; - percentiles_.clear(); - } - filename_ = config.eventLogFile(); - if (!filename_.empty()) { - of_.open(filename_, std::ios::out | std::ios::trunc); - out_ = &of_; - } - } - EventCSVLogger::update(config); -} - -void EventCSVDbgLogger::update(const Config& config) { - if (out_ && config.verboseLogLevel() < 0) { - out_ = nullptr; - } else if (!out_ && config.verboseLogLevel() >= 0) { - out_ = &LIBKINETO_DBG_STREAM; - } - if (config.verboseLogLevel() >= 0) { - percentiles_.clear(); - EventCSVLogger::update(config); - } -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/output_csv.h b/plugins/tensorboard-plugins/libkineto/src/output_csv.h deleted file mode 100644 index bca29f4db99af8aedf031aed869ff2efd3df6155..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/output_csv.h +++ /dev/null @@ -1,39 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once -#include "SampleListener.h" - -#include -#include -#include - -namespace KINETO_NAMESPACE { - -class EventCSVLogger : public SampleListener { - public: - void update(const Config& config) override; - void handleSample(int device, const Sample& sample, bool from_new_version) override; - - protected: - EventCSVLogger() : out_(nullptr) {} - - std::ostream* out_; - std::set eventNames_; - std::vector percentiles_; -}; - -class EventCSVFileLogger : public EventCSVLogger { - public: - void update(const Config& config) override; - - private: - std::ofstream of_; - std::string filename_; -}; - -class EventCSVDbgLogger : public EventCSVLogger { - public: - void update(const Config& config) override; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/output_json.cpp b/plugins/tensorboard-plugins/libkineto/src/output_json.cpp deleted file mode 100644 index 0ef22339fad15d6a78e43d7fcb7761fbbc97333b..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/output_json.cpp +++ /dev/null @@ -1,583 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "output_json.h" - -#include -#include -#include -#include - -#include "Config.h" -#ifdef HAS_CUPTI -#include "CuptiActivity.h" -#include "CuptiActivity.tpp" -#include "CuptiActivityApi.h" -#include "CudaDeviceProperties.h" -#endif // HAS_CUPTI -#include "Demangle.h" -#include "TraceSpan.h" - -#include "Logger.h" - -using std::endl; -using namespace libkineto; - -namespace KINETO_NAMESPACE { - -static constexpr int kSchemaVersion = 1; -static constexpr char kFlowStart = 's'; -static constexpr char kFlowEnd = 'f'; - -#ifdef __linux__ -static constexpr char kDefaultLogFileFmt[] = - "/tmp/libkineto_activities_{}.json"; -#else -static constexpr char kDefaultLogFileFmt[] = "libkineto_activities_{}.json"; -#endif - -std::string& ChromeTraceLogger::sanitizeStrForJSON(std::string& value) { -// Replace all backslashes with forward slash because Windows paths causing JSONDecodeError. -#ifdef _WIN32 - std::replace(value.begin(), value.end(), '\\', '/'); -#endif - return value; -} - -void ChromeTraceLogger::metadataToJSON( - const std::unordered_map& metadata) { - for (const auto& kv : metadata) { - traceOf_ << fmt::format(R"JSON( - "{}": {},)JSON", kv.first, kv.second); - } -} - -void ChromeTraceLogger::handleTraceStart( - const std::unordered_map& metadata) { - traceOf_ << fmt::format(R"JSON( -{{ - "schemaVersion": {},)JSON", kSchemaVersion); - -#ifdef HAS_CUPTI - traceOf_ << fmt::format(R"JSON( - "deviceProperties": [{} - ],)JSON", devicePropertiesJson()); -#endif - - metadataToJSON(metadata); - traceOf_ << R"JSON( - "traceEvents": [)JSON"; -} - -static std::string defaultFileName() { - return fmt::format(kDefaultLogFileFmt, processId()); -} - -void ChromeTraceLogger::openTraceFile() { - traceOf_.open(fileName_, std::ofstream::out | std::ofstream::trunc); - if (!traceOf_) { - PLOG(ERROR) << "Failed to open '" << fileName_ << "'"; - } else { - LOG(INFO) << "Tracing to " << fileName_; - } -} - -ChromeTraceLogger::ChromeTraceLogger(const std::string& traceFileName) { - fileName_ = traceFileName.empty() ? defaultFileName() : traceFileName; - traceOf_.clear(std::ios_base::badbit); - openTraceFile(); -} - -static int64_t us(int64_t timestamp) { - // It's important that this conversion is the same here and in the CPU trace. - // No rounding! - return timestamp / 1000; -} - -void ChromeTraceLogger::handleDeviceInfo( - const DeviceInfo& info, - uint64_t time) { - if (!traceOf_) { - return; - } - - // M is for metadata - // process_name needs a pid and a name arg - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "name": "process_name", "ph": "M", "ts": {}, "pid": {}, "tid": 0, - "args": {{ - "name": "{}" - }} - }}, - {{ - "name": "process_labels", "ph": "M", "ts": {}, "pid": {}, "tid": 0, - "args": {{ - "labels": "{}" - }} - }}, - {{ - "name": "process_sort_index", "ph": "M", "ts": {}, "pid": {}, "tid": 0, - "args": {{ - "sort_index": {} - }} - }},)JSON", - time, info.id, - info.name, - time, info.id, - info.label, - time, info.id, - info.id < 8 ? info.id + 0x1000000ll : info.id); - // clang-format on -} - -void ChromeTraceLogger::handleResourceInfo( - const ResourceInfo& info, - int64_t time) { - if (!traceOf_) { - return; - } - - // M is for metadata - // thread_name needs a pid and a name arg - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "name": "thread_name", "ph": "M", "ts": {}, "pid": {}, "tid": {}, - "args": {{ - "name": "{}" - }} - }}, - {{ - "name": "thread_sort_index", "ph": "M", "ts": {}, "pid": {}, "tid": {}, - "args": {{ - "sort_index": {} - }} - }},)JSON", - time, info.deviceId, info.id, - info.name, - time, info.deviceId, info.id, - info.sortIndex); - // clang-format on -} - -void ChromeTraceLogger::handleOverheadInfo( - const OverheadInfo& info, - int64_t time) { - if (!traceOf_) { - return; - } - - // TOOD: reserve pid = -1 for overhead but we need to rethink how to scale this for - // other metadata - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "name": "process_name", "ph": "M", "ts": {}, "pid": -1, "tid": 0, - "args": {{ - "name": "{}" - }} - }}, - {{ - "name": "process_sort_index", "ph": "M", "ts": {}, "pid": -1, "tid": 0, - "args": {{ - "sort_index": {} - }} - }},)JSON", - time, - info.name, - time, - 0x100000All); - // clang-format on -} - -void ChromeTraceLogger::handleTraceSpan(const TraceSpan& span) { - if (!traceOf_) { - return; - } - - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "X", "cat": "Trace", "ts": {}, "dur": {}, - "pid": "Spans", "tid": "{}", - "name": "{}{} ({})", - "args": {{ - "Op count": {} - }} - }}, - {{ - "name": "process_sort_index", "ph": "M", "ts": {}, - "pid": "Spans", "tid": 0, - "args": {{ - "sort_index": {} - }} - }},)JSON", - span.startTime, span.endTime - span.startTime, - span.name, - span.prefix, span.name, span.iteration, - span.opCount, - span.startTime, - // Large sort index to appear at the bottom - 0x20000000ll); - // clang-format on - - addIterationMarker(span); -} - -void ChromeTraceLogger::addIterationMarker(const TraceSpan& span) { - if (!traceOf_) { - return; - } - - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "name": "Iteration Start: {}", "ph": "i", "s": "g", - "pid": "Traces", "tid": "Trace {}", "ts": {} - }},)JSON", - span.name, - span.name, span.startTime); - // clang-format on -} - -static std::string traceActivityJson(const ITraceActivity& activity) { - // clang-format off - int64_t ts = activity.timestamp(); - int64_t duration = activity.duration(); - if (activity.type() == ActivityType::GPU_USER_ANNOTATION) { - // The GPU user annotations start at the same time as the - // first associated GPU activity. Since they appear later - // in the trace file, this causes a visualization issue in Chrome. - // Make it start one us earlier. - ts--; - duration++; // Still need it to end at the orginal point - } - return fmt::format(R"JSON( - "name": "{}", "pid": {}, "tid": {}, - "ts": {}, "dur": {})JSON", - activity.name(), activity.deviceId(), activity.resourceId(), - ts, duration); - // clang-format on -} - -void ChromeTraceLogger::handleGenericInstantEvent( - const libkineto::ITraceActivity& op) { - if (!traceOf_) { - return; - } - - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "i", "s": "t", "name": "{}", - "pid": {}, "tid": {}, - "ts": {}, - "args": {{ - {} - }} - }},)JSON", - op.name(), op.deviceId(), op.resourceId(), - op.timestamp(), op.metadataJson()); -} - -void ChromeTraceLogger::handleActivity( - const libkineto::ITraceActivity& op) { - if (!traceOf_) { - return; - } - - if (op.type() == ActivityType::CPU_INSTANT_EVENT) { - handleGenericInstantEvent(op); - return; - } - - const std::string op_metadata = op.metadataJson(); - std::string separator = ""; - if (op_metadata.find_first_not_of(" \t\n") != std::string::npos) { - separator = ",\n "; - } - std::string span = ""; - if (op.traceSpan()) { - span = fmt::format(R"JSON( - "Trace name": "{}", "Trace iteration": {},)JSON", - op.traceSpan()->name, - op.traceSpan()->iteration); - } - - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "X", "cat": "{}", {}, - "args": {{{} - "External id": {}{}{} - }} - }},)JSON", - toString(op.type()), traceActivityJson(op), - // args - span, - op.correlationId(), separator, op_metadata); - // clang-format on - if (op.flowId() > 0) { - handleGenericLink(op); - } -} - -void ChromeTraceLogger::handleGenericActivity( - const libkineto::GenericTraceActivity& op) { - handleActivity(op); -} - -void ChromeTraceLogger::handleGenericLink(const ITraceActivity& act) { - static struct { - int type; - char longName[24]; - char shortName[16]; - } flow_names[] = { - {kLinkFwdBwd, "forward_backward", "fwd_bwd"}, - {kLinkAsyncCpuGpu, "async_cpu_to_gpu", "async_gpu"} - }; - for (auto& flow : flow_names) { - if (act.flowType() == flow.type) { - // Link the activities via flow ID in source and destination. - // The source node must return true from flowStart() - // and the destination node false. - if (act.flowStart()) { - handleLink(kFlowStart, act, act.flowId(), flow.longName, flow.shortName); - } else { - handleLink(kFlowEnd, act, act.flowId(), flow.longName, flow.shortName); - } - return; - } - } - LOG(ERROR) << "Unknown flow type: " << act.flowType(); -} - -void ChromeTraceLogger::handleLink( - char type, - const ITraceActivity& e, - int64_t id, - const std::string& cat, - const std::string& name) { - if (!traceOf_) { - return; - } - - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "{}", "id": {}, "pid": {}, "tid": {}, "ts": {}, - "cat": "{}", "name": "{}", "bp": "e" - }},)JSON", - type, id, e.deviceId(), e.resourceId(), e.timestamp(), cat, name); - // clang-format on -} - -#ifdef HAS_CUPTI -// GPU side kernel activity -void ChromeTraceLogger::handleGpuActivity( - const GpuActivity& activity) { - if (!traceOf_) { - return; - } - const CUpti_ActivityKernel4* kernel = &activity.raw(); - constexpr int threads_per_warp = 32; - float blocks_per_sm = -1.0; - float warps_per_sm = -1.0; - int sm_count = smCount(kernel->deviceId); - if (sm_count) { - blocks_per_sm = - (kernel->gridX * kernel->gridY * kernel->gridZ) / (float) sm_count; - warps_per_sm = - blocks_per_sm * (kernel->blockX * kernel->blockY * kernel->blockZ) - / threads_per_warp; - } - - // Calculate occupancy - float occupancy = KINETO_NAMESPACE::kernelOccupancy( - kernel->deviceId, - kernel->registersPerThread, - kernel->staticSharedMemory, - kernel->dynamicSharedMemory, - kernel->blockX, - kernel->blockY, - kernel->blockZ, - blocks_per_sm); - - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "X", "cat": "Kernel", {}, - "args": {{ - "queued": {}, "device": {}, "context": {}, - "stream": {}, "correlation": {}, - "registers per thread": {}, - "shared memory": {}, - "blocks per SM": {}, - "warps per SM": {}, - "grid": [{}, {}, {}], - "block": [{}, {}, {}], - "est. achieved occupancy %": {} - }} - }},)JSON", - traceActivityJson(activity), - // args - us(kernel->queued), kernel->deviceId, kernel->contextId, - kernel->streamId, kernel->correlationId, - kernel->registersPerThread, - kernel->staticSharedMemory + kernel->dynamicSharedMemory, - blocks_per_sm, - warps_per_sm, - kernel->gridX, kernel->gridY, kernel->gridZ, - kernel->blockX, kernel->blockY, kernel->blockZ, - (int) (0.5 + occupancy * 100.0)); - // clang-format on - - auto to_id = activity.correlationId(); - handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); -} - -static std::string bandwidth(uint64_t bytes, uint64_t duration) { - return duration == 0 ? "\"N/A\"" : fmt::format("{}", bytes * 1.0 / duration); -} - -// GPU side memcpy activity -void ChromeTraceLogger::handleGpuActivity( - const GpuActivity& activity) { - if (!traceOf_) { - return; - } - const CUpti_ActivityMemcpy& memcpy = activity.raw(); - VLOG(2) << memcpy.correlationId << ": MEMCPY"; - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "X", "cat": "Memcpy", {}, - "args": {{ - "device": {}, "context": {}, - "stream": {}, "correlation": {}, - "bytes": {}, "memory bandwidth (GB/s)": {} - }} - }},)JSON", - traceActivityJson(activity), - // args - memcpy.deviceId, memcpy.contextId, - memcpy.streamId, memcpy.correlationId, - memcpy.bytes, bandwidth(memcpy.bytes, memcpy.end - memcpy.start)); - // clang-format on - - int64_t to_id = activity.correlationId(); - handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); -} - -// GPU side memcpy activity -void ChromeTraceLogger::handleGpuActivity( - const GpuActivity& activity) { - if (!traceOf_) { - return; - } - const CUpti_ActivityMemcpy2& memcpy = activity.raw(); - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "X", "cat": "Memcpy", {}, - "args": {{ - "fromDevice": {}, "inDevice": {}, "toDevice": {}, - "fromContext": {}, "inContext": {}, "toContext": {}, - "stream": {}, "correlation": {}, - "bytes": {}, "memory bandwidth (GB/s)": {} - }} - }},)JSON", - traceActivityJson(activity), - // args - memcpy.srcDeviceId, memcpy.deviceId, memcpy.dstDeviceId, - memcpy.srcContextId, memcpy.contextId, memcpy.dstContextId, - memcpy.streamId, memcpy.correlationId, - memcpy.bytes, bandwidth(memcpy.bytes, memcpy.end - memcpy.start)); - // clang-format on - - int64_t to_id = activity.correlationId(); - handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); -} - -void ChromeTraceLogger::handleGpuActivity( - const GpuActivity& activity) { - if (!traceOf_) { - return; - } - const CUpti_ActivityMemset& memset = activity.raw(); - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "ph": "X", "cat": "Memset", {}, - "args": {{ - "device": {}, "context": {}, - "stream": {}, "correlation": {}, - "bytes": {}, "memory bandwidth (GB/s)": {} - }} - }},)JSON", - traceActivityJson(activity), - // args - memset.deviceId, memset.contextId, - memset.streamId, memset.correlationId, - memset.bytes, bandwidth(memset.bytes, memset.end - memset.start)); - // clang-format on - - int64_t to_id = activity.correlationId(); - handleLink(kFlowEnd, activity, to_id, "async_cpu_to_gpu", "async_gpu"); -} -#endif // HAS_CUPTI - -void ChromeTraceLogger::finalizeTrace( - const Config& /*unused*/, - std::unique_ptr /*unused*/, - int64_t endTime, - std::unordered_map>& metadata) { - if (!traceOf_) { - LOG(ERROR) << "Failed to write to log file!"; - return; - } - LOG(INFO) << "Chrome Trace written to " << fileName_; - // clang-format off - traceOf_ << fmt::format(R"JSON( - {{ - "name": "Record Window End", "ph": "i", "s": "g", - "pid": "", "tid": "", "ts": {} - }} - ],)JSON", - endTime); - -#if !USE_GOOGLE_LOG - std::unordered_map PreparedMetadata; - for (const auto& kv : metadata) { - // Skip empty log buckets, ex. skip ERROR if its empty. - if (!kv.second.empty()) { - std::string value = "["; - // Ex. Each metadata from logger is a list of strings, expressed in JSON as - // "ERROR": ["Error 1", "Error 2"], - // "WARNING": ["Warning 1", "Warning 2", "Warning 3"], - // ... - int mdv_count = kv.second.size(); - for (const auto& v : kv.second) { - value.append("\"" + v + "\""); - if(mdv_count > 1) { - value.append(","); - mdv_count--; - } - } - value.append("]"); - PreparedMetadata[kv.first] = sanitizeStrForJSON(value); - } - } - metadataToJSON(PreparedMetadata); -#endif // !USE_GOOGLE_LOG - - // Putting this here because the last entry MUST not end with a comma. - traceOf_ << fmt::format(R"JSON( - "traceName": "{}" -}})JSON", sanitizeStrForJSON(fileName_)); - // clang-format on - - traceOf_.close(); -} - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/output_json.h b/plugins/tensorboard-plugins/libkineto/src/output_json.h deleted file mode 100644 index 5a8a81e4a9fdeef09b0e9ace59b964d5ab99b7ad..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/output_json.h +++ /dev/null @@ -1,91 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include -#include - -#ifdef HAS_CUPTI -#include -#endif -#include "GenericTraceActivity.h" -#include "output_base.h" - -namespace KINETO_NAMESPACE { - // Previous declaration of TraceSpan is struct. Must match the same here. - struct TraceSpan; -} - -namespace KINETO_NAMESPACE { - -class Config; - -class ChromeTraceLogger : public libkineto::ActivityLogger { - public: - explicit ChromeTraceLogger(const std::string& traceFileName); - - // Note: the caller of these functions should handle concurrency - // i.e., we these functions are not thread-safe - void handleDeviceInfo( - const DeviceInfo& info, - uint64_t time) override; - - void handleOverheadInfo(const OverheadInfo& info, int64_t time) override; - - void handleResourceInfo(const ResourceInfo& info, int64_t time) override; - - void handleTraceSpan(const TraceSpan& span) override; - - void handleActivity(const ITraceActivity& activity) override; - void handleGenericActivity(const GenericTraceActivity& activity) override; - -#ifdef HAS_CUPTI - void handleGpuActivity(const GpuActivity& activity) override; - void handleGpuActivity(const GpuActivity& activity) override; - void handleGpuActivity(const GpuActivity& activity) override; - void handleGpuActivity(const GpuActivity& activity) override; -#endif // HAS_CUPTI - - void handleTraceStart( - const std::unordered_map& metadata) override; - - void finalizeTrace( - const Config& config, - std::unique_ptr buffers, - int64_t endTime, - std::unordered_map>& metadata) override; - - std::string traceFileName() const { - return fileName_; - } - - private: - - // Create a flow event (arrow) - void handleLink( - char type, - const ITraceActivity& e, - int64_t id, - const std::string& cat, - const std::string& name); - - void addIterationMarker(const TraceSpan& span); - - void openTraceFile(); - - void handleGenericInstantEvent(const ITraceActivity& op); - - void handleGenericLink(const ITraceActivity& activity); - - void metadataToJSON(const std::unordered_map& metadata); - - std::string& sanitizeStrForJSON(std::string& value); - - std::string fileName_; - std::ofstream traceOf_; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/src/output_membuf.h b/plugins/tensorboard-plugins/libkineto/src/output_membuf.h deleted file mode 100644 index ef6aadeb65728e0e05e454f98b32ccecca229cf4..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/src/output_membuf.h +++ /dev/null @@ -1,130 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include -#include - -#ifdef HAS_CUPTI -#include -#endif - -#include "Config.h" -#include "GenericTraceActivity.h" -#ifdef HAS_CUPTI -#include "CuptiActivity.h" -#include "CuptiActivity.tpp" -#endif // HAS_CUPTI -#include "output_base.h" - -namespace KINETO_NAMESPACE { - -class Config; - -class MemoryTraceLogger : public ActivityLogger { - public: - MemoryTraceLogger(const Config& config) : config_(config.clone()) { - activities_.reserve(100000); - } - - // Note: the caller of these functions should handle concurrency - // i.e., these functions are not thread-safe - void handleDeviceInfo( - const DeviceInfo& info, - uint64_t time) override { - deviceInfoList_.emplace_back(info, time); - } - - void handleResourceInfo(const ResourceInfo& info, int64_t time) override { - resourceInfoList_.emplace_back(info, time); - } - - void handleOverheadInfo(const OverheadInfo& info, int64_t time) override {} - - void handleTraceSpan(const TraceSpan& span) override { - // Handled separately - } - - template - void addActivityWrapper(const T& act) { - wrappers_.push_back(std::make_unique(act)); - activities_.push_back(wrappers_.back().get()); - } - - // Just add the pointer to the list - ownership of the underlying - // objects must be transferred in ActivityBuffers via finalizeTrace - void handleActivity(const ITraceActivity& activity) override { - activities_.push_back(&activity); - } - void handleGenericActivity(const GenericTraceActivity& activity) override { - addActivityWrapper(activity); - } - -#ifdef HAS_CUPTI - void handleGpuActivity(const GpuActivity& activity) override { - addActivityWrapper(activity); - } - void handleGpuActivity(const GpuActivity& activity) override { - addActivityWrapper(activity); - } - void handleGpuActivity(const GpuActivity& activity) override { - addActivityWrapper(activity); - } - void handleGpuActivity(const GpuActivity& activity) override { - addActivityWrapper(activity); - } -#endif // HAS_CUPTI - - void handleTraceStart( - const std::unordered_map& metadata) override { - metadata_ = metadata; - } - - void finalizeTrace( - const Config& config, - std::unique_ptr buffers, - int64_t endTime, - std::unordered_map>& metadata) override { - buffers_ = std::move(buffers); - endTime_ = endTime; - } - - const std::vector* traceActivities() { - return &activities_; - } - - void log(ActivityLogger& logger) { - logger.handleTraceStart(metadata_); - for (auto& activity : activities_) { - activity->log(logger); - } - for (auto& p : deviceInfoList_) { - logger.handleDeviceInfo(p.first, p.second); - } - for (auto& p : resourceInfoList_) { - logger.handleResourceInfo(p.first, p.second); - } - for (auto& cpu_trace_buffer : buffers_->cpu) { - logger.handleTraceSpan(cpu_trace_buffer->span); - } - // Hold on to the buffers - logger.finalizeTrace(*config_, nullptr, endTime_, loggerMetadata_); - } - - private: - - std::unique_ptr config_; - // Optimization: Remove unique_ptr by keeping separate vector per type - std::vector activities_; - std::vector> wrappers_; - std::vector> deviceInfoList_; - std::vector> resourceInfoList_; - std::unique_ptr buffers_; - std::unordered_map metadata_; - std::unordered_map> loggerMetadata_; - int64_t endTime_{0}; -}; - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/test/CMakeLists.txt b/plugins/tensorboard-plugins/libkineto/test/CMakeLists.txt deleted file mode 100644 index ca54460b36cd4ade93918c8512f1309b48552e65..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -cmake_minimum_required(VERSION 3.5 FATAL_ERROR) - -# TODO diff --git a/plugins/tensorboard-plugins/libkineto/test/ConfigTest.cpp b/plugins/tensorboard-plugins/libkineto/test/ConfigTest.cpp deleted file mode 100644 index 16bc86e751cefdbee1d48aeb79fc849b7d151a18..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/ConfigTest.cpp +++ /dev/null @@ -1,315 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "include/Config.h" - -#include -#include -#include -#include - -using namespace std::chrono; -using namespace KINETO_NAMESPACE; - -TEST(ParseTest, Whitespace) { - Config cfg; - // Check that various types of whitespace is ignored - EXPECT_TRUE(cfg.parse("")); - EXPECT_TRUE(cfg.parse(" ")); - EXPECT_TRUE(cfg.parse("\t")); - EXPECT_TRUE(cfg.parse("\n")); - EXPECT_TRUE(cfg.parse(" ")); - EXPECT_TRUE(cfg.parse("\t \n \t\t\n\n")); - // Only the above characters are supported - EXPECT_FALSE(cfg.parse("\r\n")); -} - -TEST(ParseTest, Comment) { - Config cfg; - // Anything following a '#' should be ignored, up to a newline - EXPECT_TRUE(cfg.parse("# comment")); - EXPECT_TRUE(cfg.parse(" # ~!@#$")); - EXPECT_TRUE(cfg.parse("\t#abc")); - EXPECT_TRUE(cfg.parse("###\n##")); - EXPECT_TRUE(cfg.parse("EVENTS=util ##ok")); - EXPECT_TRUE(cfg.parse("EVENTS=util ## EVENTS=instruction")); - // Whatever appears before the comment must be valid format - EXPECT_FALSE(cfg.parse("util ## not ok")); - EXPECT_FALSE(cfg.parse("## ok \n blah # not OK")); - // Check that a comment does not affect config parsing - EXPECT_TRUE(cfg.parse("SAMPLE_PERIOD_MSECS = 1 # Sample every millisecond")); - EXPECT_EQ(cfg.samplePeriod(), milliseconds(1)); -} - -TEST(ParseTest, Format) { - Config cfg; - // The basic format is just "name = value". - // Where both value and name can be almost anything. - // Leading and trailing whitespace should be removed - // for both 'name' and 'value', but internal whitespace is not. - EXPECT_FALSE(cfg.parse("events")); - EXPECT_TRUE(cfg.parse("events=")); - EXPECT_FALSE(cfg.parse("=events=")); - EXPECT_TRUE(cfg.parse("events=1,2,3")); - // Only one setting per line - EXPECT_FALSE(cfg.parse("events = 1,2,3 ; metrics = 4,5,6")); - // Names are case sensitive - EXPECT_TRUE(cfg.parse("EVENTS = 1,2,3 \n metrics = 4,5,6")); - EXPECT_EQ(cfg.eventNames(), std::set({"1", "2", "3"})); - EXPECT_EQ(cfg.metricNames().size(), 0); - // Leading and trailing whitespace removed for event and metric names, - // but not internal. - EXPECT_TRUE( - cfg.parse("EVENTS = 1, 2, 3 \n \tMETRICS\t = \t4,\t5\t,\ts i x ")); - EXPECT_EQ(cfg.eventNames(), std::set({"1", "2", "3"})); - EXPECT_EQ(cfg.metricNames(), std::set({"4", "5", "s i x"})); -} - -TEST(ParseTest, DefaultActivityTypes) { - Config cfg; - cfg.validate(std::chrono::system_clock::now()); - auto all_activities = activityTypes(); - // TODO: introduce optional activities - EXPECT_EQ(cfg.selectedActivityTypes(), - std::set(all_activities.begin(), all_activities.end() - 1)); -} - -TEST(ParseTest, ActivityTypes) { - Config cfg; - EXPECT_FALSE(cfg.parse("ACTIVITY_TYPES")); - EXPECT_TRUE(cfg.parse("ACTIVITY_TYPES=")); - EXPECT_FALSE(cfg.parse("=ACTIVITY_TYPES=")); - - EXPECT_EQ(cfg.selectedActivityTypes(), - std::set({ActivityType::CPU_OP, - ActivityType::CPU_INSTANT_EVENT, - ActivityType::PYTHON_FUNCTION, - ActivityType::USER_ANNOTATION, - ActivityType::GPU_USER_ANNOTATION, - ActivityType::GPU_MEMCPY, - ActivityType::GPU_MEMSET, - ActivityType::CONCURRENT_KERNEL, - ActivityType::EXTERNAL_CORRELATION, - ActivityType::GLOW_RUNTIME, - ActivityType::CUDA_RUNTIME, - ActivityType::CUDA_PROFILER_RANGE})); - - Config cfg2; - EXPECT_TRUE(cfg2.parse("ACTIVITY_TYPES=gpu_memcpy,gpu_MeMsEt,kernel")); - EXPECT_EQ(cfg2.selectedActivityTypes(), - std::set({ActivityType::GPU_MEMCPY, - ActivityType::GPU_MEMSET, - ActivityType::CONCURRENT_KERNEL})); - - EXPECT_TRUE(cfg2.parse("ACTIVITY_TYPES = cuda_Runtime,")); - EXPECT_EQ(cfg2.selectedActivityTypes(), - std::set({ActivityType::CUDA_RUNTIME})); - - // Should throw an exception because incorrect activity name - EXPECT_FALSE(cfg2.parse("ACTIVITY_TYPES = memcopy,cuda_runtime")); - - EXPECT_TRUE(cfg2.parse("ACTIVITY_TYPES = cpu_op")); - EXPECT_EQ(cfg2.selectedActivityTypes(), - std::set({ActivityType::CPU_OP})); -} - -TEST(ParseTest, SamplePeriod) { - Config cfg; - EXPECT_TRUE(cfg.parse("SAMPLE_PERIOD_MSECS=10")); - EXPECT_EQ(cfg.samplePeriod(), milliseconds(10)); - EXPECT_TRUE(cfg.parse("SAMPLE_PERIOD_MSECS=0")); - cfg.validate(std::chrono::system_clock::now()); - // 0 should be adjustd up to 1 - EXPECT_EQ(cfg.samplePeriod(), milliseconds(1)); - // Negative and non-int values should fail - EXPECT_FALSE(cfg.parse("SAMPLE_PERIOD_MSECS=-10")); - EXPECT_FALSE(cfg.parse("SAMPLE_PERIOD_MSECS=1.5")); - EXPECT_FALSE(cfg.parse("SAMPLE_PERIOD_MSECS=")); - EXPECT_FALSE(cfg.parse("SAMPLE_PERIOD_MSECS=string")); - EXPECT_EQ(cfg.samplePeriod(), milliseconds(1)); -} - -TEST(ParseTest, MultiplexPeriod) { - Config cfg; - auto now = std::chrono::system_clock::now(); - - EXPECT_TRUE(cfg.parse("SAMPLE_PERIOD_MSECS=100\nMULTIPLEX_PERIOD_MSECS=100")); - EXPECT_EQ(cfg.multiplexPeriod(), milliseconds(100)); - EXPECT_TRUE(cfg.parse("MULTIPLEX_PERIOD_MSECS = 0")); - cfg.validate(now); - // Adjusted to match sample period - EXPECT_EQ(cfg.multiplexPeriod(), milliseconds(100)); - EXPECT_TRUE(cfg.parse("MULTIPLEX_PERIOD_MSECS \t= \t 750 \n")); - cfg.validate(now); - // Adjusted to match multiple of sample period - EXPECT_EQ(cfg.multiplexPeriod(), milliseconds(800)); - EXPECT_FALSE(cfg.parse("MULTIPLEX_PERIOD_MSECS=-10")); - EXPECT_FALSE(cfg.parse("MULTIPLEX_PERIOD_MSECS=1.5")); - EXPECT_FALSE(cfg.parse("MULTIPLEX_PERIOD_MSECS=")); - EXPECT_FALSE(cfg.parse("MULTIPLEX_PERIOD_MSECS=string")); - // Previous value not affected - EXPECT_EQ(cfg.multiplexPeriod(), milliseconds(800)); -} - -TEST(ParseTest, ReportPeriod) { - Config cfg; - EXPECT_TRUE(cfg.parse("REPORT_PERIOD_SECS=1")); - EXPECT_EQ(cfg.reportPeriod(), seconds(1)); - // Whitespace - EXPECT_TRUE(cfg.parse("REPORT_PERIOD_SECS = \t100")); - EXPECT_EQ(cfg.reportPeriod(), seconds(100)); - // Invalid types - EXPECT_FALSE(cfg.parse("REPORT_PERIOD_SECS=-1")); - EXPECT_EQ(cfg.reportPeriod(), seconds(100)); -} - -TEST(ParseTest, SamplesPerReport) { - Config cfg; - auto now = std::chrono::system_clock::now(); - - EXPECT_TRUE(cfg.parse(R"( - SAMPLE_PERIOD_MSECS = 1000 - REPORT_PERIOD_SECS = 1 - SAMPLES_PER_REPORT = 10)")); - cfg.validate(now); - // Adjusted down to one sample per report - EXPECT_EQ(cfg.samplesPerReport(), 1); - EXPECT_TRUE(cfg.parse(R"( - SAMPLE_PERIOD_MSECS = 1000 - REPORT_PERIOD_SECS = 10 - SAMPLES_PER_REPORT = 10)")); - cfg.validate(now); - // No adjustment needed - EXPECT_EQ(cfg.samplesPerReport(), 10); - EXPECT_TRUE(cfg.parse(R"( - SAMPLE_PERIOD_MSECS = 1000 - REPORT_PERIOD_SECS = 2 - SAMPLES_PER_REPORT = 10)")); - cfg.validate(now); - // Adjusted to 2 samples per report - EXPECT_EQ(cfg.samplesPerReport(), 2); - EXPECT_TRUE(cfg.parse(R"( - SAMPLE_PERIOD_MSECS = 200 - REPORT_PERIOD_SECS = 2 - SAMPLES_PER_REPORT = 10)")); - cfg.validate(now); - // No adjustment needed - EXPECT_EQ(cfg.samplesPerReport(), 10); - EXPECT_TRUE(cfg.parse("SAMPLES_PER_REPORT=0")); - cfg.validate(now); - // Adjusted up to 1 - EXPECT_EQ(cfg.samplesPerReport(), 1); - // Invalid value types - EXPECT_FALSE(cfg.parse("SAMPLES_PER_REPORT=-10")); - EXPECT_FALSE(cfg.parse("SAMPLES_PER_REPORT=1.5")); - EXPECT_EQ(cfg.samplesPerReport(), 1); - - EXPECT_TRUE(cfg.parse(R"( - SAMPLE_PERIOD_MSECS=1000 - MULTIPLEX_PERIOD_MSECS=500 # Must be a multiple of sample period - REPORT_PERIOD_SECS=0 # Must be non-zero multiple of multiplex period - SAMPLES_PER_REPORT=5 # Max report period / multiplex period)")); - cfg.validate(now); - // Multiple adjustments - EXPECT_EQ(cfg.samplePeriod(), milliseconds(1000)); - EXPECT_EQ(cfg.multiplexPeriod(), milliseconds(1000)); - EXPECT_EQ(cfg.reportPeriod(), seconds(1)); - EXPECT_EQ(cfg.samplesPerReport(), 1); -} - -TEST(ParseTest, EnableSigUsr2) { - Config cfg; - EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=yes")); - EXPECT_TRUE(cfg.sigUsr2Enabled()); - EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=no")); - EXPECT_FALSE(cfg.sigUsr2Enabled()); - EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=YES")); - EXPECT_TRUE(cfg.sigUsr2Enabled()); - EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=NO")); - EXPECT_FALSE(cfg.sigUsr2Enabled()); - EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=Y")); - EXPECT_TRUE(cfg.sigUsr2Enabled()); - EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=N")); - EXPECT_FALSE(cfg.sigUsr2Enabled()); - EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=T")); - EXPECT_TRUE(cfg.sigUsr2Enabled()); - EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=F")); - EXPECT_FALSE(cfg.sigUsr2Enabled()); - EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=true")); - EXPECT_TRUE(cfg.sigUsr2Enabled()); - EXPECT_TRUE(cfg.parse("ENABLE_SIGUSR2=false")); - EXPECT_FALSE(cfg.sigUsr2Enabled()); - EXPECT_FALSE(cfg.parse("ENABLE_SIGUSR2= ")); - EXPECT_FALSE(cfg.parse("ENABLE_SIGUSR2=2")); - EXPECT_FALSE(cfg.parse("ENABLE_SIGUSR2=-1")); - EXPECT_FALSE(cfg.parse("ENABLE_SIGUSR2=yep")); -} - -TEST(ParseTest, DeviceMask) { - Config cfg; - // Single device - EXPECT_TRUE(cfg.parse("EVENTS_ENABLED_DEVICES = 0")); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(0)); - EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(1)); - - // Two devices, internal whitespace - EXPECT_TRUE(cfg.parse("EVENTS_ENABLED_DEVICES = 1, 2")); - EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(0)); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(1)); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(2)); - EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(3)); - - // Three devices, check that previous devices are ignored - EXPECT_TRUE(cfg.parse("EVENTS_ENABLED_DEVICES = 0, 2,4")); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(0)); - EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(1)); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(2)); - EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(3)); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(4)); - EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(5)); - - // Repeated numbers have no effect - EXPECT_TRUE(cfg.parse("EVENTS_ENABLED_DEVICES = 0,1,1,1,2,3,2,1,3,7,7,3")); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(0)); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(1)); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(2)); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(3)); - EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(4)); - EXPECT_FALSE(cfg.eventProfilerEnabledForDevice(6)); - EXPECT_TRUE(cfg.eventProfilerEnabledForDevice(7)); - - // 8 is larger than the max allowed - EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = 3,8")); - - // 300 cannot be held in an uint8_t - EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = 300")); - - // Various illegal cases - EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = 0,1,two,three")); - EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = 0,1,,2")); - EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = -1")); - EXPECT_FALSE(cfg.parse("EVENTS_ENABLED_DEVICES = 1.0")); -} - -TEST(ParseTest, RequestTime) { - Config cfg; - system_clock::time_point now = system_clock::now(); - int64_t tgood_ms = - duration_cast(now.time_since_epoch()).count(); - EXPECT_TRUE(cfg.parse(fmt::format("REQUEST_TIMESTAMP = {}", tgood_ms))); - - tgood_ms = duration_cast((now - seconds(5)).time_since_epoch()) - .count(); - EXPECT_TRUE(cfg.parse(fmt::format("REQUEST_TIMESTAMP = {}", tgood_ms))); - - int64_t tbad_ms = - duration_cast((now - seconds(20)).time_since_epoch()) - .count(); - EXPECT_FALSE(cfg.parse(fmt::format("REQUEST_TIMESTAMP = {}", tbad_ms))); - - EXPECT_FALSE(cfg.parse("REQUEST_TIMESTAMP = 0")); - EXPECT_FALSE(cfg.parse("REQUEST_TIMESTAMP = -1")); - - tbad_ms = duration_cast((now + seconds(10)).time_since_epoch()) - .count(); - EXPECT_FALSE(cfg.parse(fmt::format("REQUEST_TIMESTAMP = {}", tbad_ms))); -} diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiActivityProfilerTest.cpp b/plugins/tensorboard-plugins/libkineto/test/CuptiActivityProfilerTest.cpp deleted file mode 100644 index 6e67980ee31a3386580974033201b7acae75d22b..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/CuptiActivityProfilerTest.cpp +++ /dev/null @@ -1,629 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include -#include -#include -#include -#include -#include - -#ifdef __linux__ -#include -#include -#include -#endif - -#include "include/libkineto.h" -#include "include/Config.h" -#include "src/CuptiActivityProfiler.h" -#include "src/ActivityTrace.h" -#include "src/CuptiActivityApi.h" -#include "src/output_base.h" -#include "src/output_json.h" -#include "src/output_membuf.h" - -#include "src/Logger.h" -#include "test/MockActivitySubProfiler.h" - -using namespace std::chrono; -using namespace KINETO_NAMESPACE; - -#define CUDA_LAUNCH_KERNEL CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000 -#define CUDA_MEMCPY CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020 - -namespace { -const TraceSpan& defaultTraceSpan() { - static TraceSpan span(0, 0, "Unknown", ""); - return span; -} -} - -// Provides ability to easily create a few test CPU-side ops -struct MockCpuActivityBuffer : public CpuTraceBuffer { - MockCpuActivityBuffer(int64_t startTime, int64_t endTime) { - span = TraceSpan(startTime, endTime,"Test trace"); - gpuOpCount = 0; - } - - void addOp(std::string name, int64_t startTime, int64_t endTime, int64_t correlation) { - GenericTraceActivity op(span, ActivityType::CPU_OP, name); - op.startTime = startTime; - op.endTime = endTime; - op.resource = systemThreadId(); - op.id = correlation; - activities.push_back(std::move(op)); - span.opCount++; - } -}; - -// Provides ability to easily create a few test CUPTI ops -struct MockCuptiActivityBuffer { - void addCorrelationActivity(int64_t correlation, CUpti_ExternalCorrelationKind externalKind, int64_t externalId) { - auto& act = *(CUpti_ActivityExternalCorrelation*) malloc(sizeof(CUpti_ActivityExternalCorrelation)); - act.kind = CUPTI_ACTIVITY_KIND_EXTERNAL_CORRELATION; - act.externalId = externalId; - act.externalKind = externalKind; - act.correlationId = correlation; - activities.push_back(reinterpret_cast(&act)); - } - - void addRuntimeActivity( - CUpti_runtime_api_trace_cbid_enum cbid, - int64_t start_us, int64_t end_us, int64_t correlation) { - auto& act = createActivity( - start_us, end_us, correlation); - act.kind = CUPTI_ACTIVITY_KIND_RUNTIME; - act.cbid = cbid; - act.threadId = threadId(); - activities.push_back(reinterpret_cast(&act)); - } - - void addKernelActivity( - int64_t start_us, int64_t end_us, int64_t correlation) { - auto& act = createActivity( - start_us, end_us, correlation); - act.kind = CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL; - act.deviceId = 0; - act.streamId = 1; - act.name = "kernel"; - act.gridX = act.gridY = act.gridZ = 1; - act.blockX = act.blockY = act.blockZ = 1; - activities.push_back(reinterpret_cast(&act)); - } - - void addMemcpyActivity( - int64_t start_us, int64_t end_us, int64_t correlation) { - auto& act = createActivity( - start_us, end_us, correlation); - act.kind = CUPTI_ACTIVITY_KIND_MEMCPY; - act.deviceId = 0; - act.streamId = 2; - act.copyKind = CUPTI_ACTIVITY_MEMCPY_KIND_HTOD; - act.srcKind = CUPTI_ACTIVITY_MEMORY_KIND_PINNED; - act.dstKind = CUPTI_ACTIVITY_MEMORY_KIND_DEVICE; - activities.push_back(reinterpret_cast(&act)); - } - - template - T& createActivity( - int64_t start_us, int64_t end_us, int64_t correlation) { - T& act = *static_cast(malloc(sizeof(T))); - bzero(&act, sizeof(act)); - act.start = start_us * 1000; - act.end = end_us * 1000; - act.correlationId = correlation; - return act; - } - - ~MockCuptiActivityBuffer() { - for (CUpti_Activity* act : activities) { - free(act); - } - } - - std::vector activities; -}; - -// Mock parts of the CuptiActivityApi -class MockCuptiActivities : public CuptiActivityApi { - public: - virtual int smCount() override { - return 10; - } - - virtual const std::pair processActivities( - CuptiActivityBufferMap&, /*unused*/ - std::function handler) override { - for (CUpti_Activity* act : activityBuffer->activities) { - handler(act); - } - return {activityBuffer->activities.size(), 100}; - } - - virtual std::unique_ptr - activityBuffers() override { - auto map = std::make_unique(); - auto buf = std::make_unique(100); - uint8_t* addr = buf->data(); - (*map)[addr] = std::move(buf); - return map; - } - - void bufferRequestedOverride(uint8_t** buffer, size_t* size, size_t* maxNumRecords) { - this->bufferRequested(buffer, size, maxNumRecords); - } - - std::unique_ptr activityBuffer; -}; - - -// Common setup / teardown and helper functions -class CuptiActivityProfilerTest : public ::testing::Test { - protected: - void SetUp() override { - profiler_ = std::make_unique( - cuptiActivities_, /*cpu only*/ false); - cfg_ = std::make_unique(); - cfg_->validate(std::chrono::system_clock::now()); - loggerFactory.addProtocol("file", [](const std::string& url) { - return std::unique_ptr(new ChromeTraceLogger(url)); - }); - } - - std::unique_ptr cfg_; - MockCuptiActivities cuptiActivities_; - std::unique_ptr profiler_; - ActivityLoggerFactory loggerFactory; -}; - -void checkTracefile(const char* filename) { -#ifdef __linux__ - // Check that the expected file was written and that it has some content - int fd = open(filename, O_RDONLY); - if (!fd) { - perror(filename); - } - EXPECT_TRUE(fd); - // Should expect at least 100 bytes - struct stat buf{}; - fstat(fd, &buf); - EXPECT_GT(buf.st_size, 100); - close(fd); -#endif -} - -TEST(CuptiActivityProfiler, AsyncTrace) { - std::vector log_modules( - {"CuptiActivityProfiler.cpp", "output_json.cpp"}); - SET_LOG_VERBOSITY_LEVEL(1, log_modules); - - MockCuptiActivities activities; - CuptiActivityProfiler profiler(activities, /*cpu only*/ true); - - char filename[] = "/tmp/libkineto_testXXXXXX.json"; - mkstemps(filename, 5); - - Config cfg; - - int iter = 0; - int warmup = 5; - auto now = system_clock::now(); - auto startTime = now + seconds(10); - - bool success = cfg.parse(fmt::format(R"CFG( - ACTIVITIES_WARMUP_PERIOD_SECS = {} - ACTIVITIES_DURATION_SECS = 1 - ACTIVITIES_LOG_FILE = {} - PROFILE_START_TIME = {} - )CFG", warmup, filename, duration_cast(startTime.time_since_epoch()).count())); - - EXPECT_TRUE(success); - EXPECT_FALSE(profiler.isActive()); - - auto logger = std::make_unique(cfg.activitiesLogFile()); - - // Usually configuration is done when now is startTime - warmup to kick off warmup - // but start right away in the test - profiler.configure(cfg, now); - profiler.setLogger(logger.get()); - - EXPECT_TRUE(profiler.isActive()); - - // fast forward in time and we have reached the startTime - now = startTime; - - // Run the profiler - // Warmup - // performRunLoopStep is usually called by the controller loop and takes - // the current time and the controller's next wakeup time. - profiler.performRunLoopStep( - /* Current time */ now, /* Next wakeup time */ now); - - auto next = now + milliseconds(1000); - - // performRunLoopStep can also be called by an application thread to update iteration count - // since this config does not use iteration this should have no effect on the state - while (++iter < 20) { - profiler.performRunLoopStep(now, now, iter); - } - - // Runloop should now be in collect state, so start workload - // Perform another runloop step, passing in the end profile time as current. - // This should terminate collection - profiler.performRunLoopStep( - /* Current time */ next, /* Next wakeup time */ next); - // One step needed for each of the Process and Finalize phases - // Doesn't really matter what times we pass in here. - - EXPECT_TRUE(profiler.isActive()); - - auto nextnext = next + milliseconds(1000); - - while (++iter < 40) { - profiler.performRunLoopStep(next, next, iter); - } - - EXPECT_TRUE(profiler.isActive()); - - profiler.performRunLoopStep(nextnext,nextnext); - profiler.performRunLoopStep(nextnext,nextnext); - - // Assert that tracing has completed - EXPECT_FALSE(profiler.isActive()); - - checkTracefile(filename); -} - -TEST(CuptiActivityProfiler, AsyncTraceUsingIter) { - std::vector log_modules( - {"CuptiActivityProfiler.cpp", "output_json.cpp"}); - SET_LOG_VERBOSITY_LEVEL(1, log_modules); - - auto runIterTest = [&]( - int start_iter, int warmup_iters, int trace_iters) { - - LOG(INFO ) << "Async Trace Test: start_iteration = " << start_iter - << " warmup iterations = " << warmup_iters - << " trace iterations = " << trace_iters; - - MockCuptiActivities activities; - CuptiActivityProfiler profiler(activities, /*cpu only*/ true); - - char filename[] = "/tmp/libkineto_testXXXXXX.json"; - mkstemps(filename, 5); - - Config cfg; - - int iter = 0; - auto now = system_clock::now(); - - bool success = cfg.parse(fmt::format(R"CFG( - PROFILE_START_ITERATION = {} - ACTIVITIES_WARMUP_ITERATIONS={} - ACTIVITIES_ITERATIONS={} - ACTIVITIES_DURATION_SECS = 1 - ACTIVITIES_LOG_FILE = {} - )CFG", start_iter, warmup_iters, trace_iters, filename)); - - EXPECT_TRUE(success); - EXPECT_FALSE(profiler.isActive()); - - auto logger = std::make_unique(cfg.activitiesLogFile()); - - // Usually configuration is done when now is startIter - warmup iter to kick off warmup - // but start right away in the test - while (iter < (start_iter - warmup_iters)) { - profiler.performRunLoopStep(now, now, iter++); - } - - profiler.configure(cfg, now); - profiler.setLogger(logger.get()); - - EXPECT_TRUE(profiler.isActive()); - - // fast forward in time, mimicking what will happen in reality - now += seconds(10); - auto next = now + milliseconds(1000); - - // this call to runloop step should not be effecting the state - profiler.performRunLoopStep(now, next); - EXPECT_TRUE(profiler.isActive()); - - // start trace collection - while (iter < start_iter) { - profiler.performRunLoopStep(now, next, iter++); - } - - // Runloop should now be in collect state, so start workload - - while (iter < (start_iter + trace_iters)) { - profiler.performRunLoopStep(now, next, iter++); - } - - // One step is required for each of the Process and Finalize phases - // Doesn't really matter what times we pass in here. - if (iter >= (start_iter + trace_iters)) { - profiler.performRunLoopStep(now, next, iter++); - } - EXPECT_TRUE(profiler.isActive()); - - auto nextnext = next + milliseconds(1000); - - profiler.performRunLoopStep(nextnext, nextnext); - profiler.performRunLoopStep(nextnext, nextnext); - - // Assert that tracing has completed - EXPECT_FALSE(profiler.isActive()); - - checkTracefile(filename); - }; - - // start iter = 50, warmup iters = 5, trace iters = 10 - runIterTest(50, 5, 10); - // should be able to start at 0 iteration - runIterTest(0, 0, 2); - runIterTest(0, 5, 5); -} - -TEST_F(CuptiActivityProfilerTest, SyncTrace) { - using ::testing::Return; - using ::testing::ByMove; - - // Verbose logging is useful for debugging - std::vector log_modules( - {"CuptiActivityProfiler.cpp"}); - SET_LOG_VERBOSITY_LEVEL(2, log_modules); - - // Start and stop profiling - CuptiActivityProfiler profiler(cuptiActivities_, /*cpu only*/ false); - int64_t start_time_us = 100; - int64_t duration_us = 300; - auto start_time = time_point(microseconds(start_time_us)); - profiler.configure(*cfg_, start_time); - profiler.startTrace(start_time); - profiler.stopTrace(start_time + microseconds(duration_us)); - - profiler.recordThreadInfo(); - - // Log some cpu ops - auto cpuOps = std::make_unique( - start_time_us, start_time_us + duration_us); - cpuOps->addOp("op1", 120, 150, 1); - cpuOps->addOp("op2", 130, 140, 2); - cpuOps->addOp("op3", 200, 250, 3); - profiler.transferCpuTrace(std::move(cpuOps)); - - // And some GPU ops - auto gpuOps = std::make_unique(); - gpuOps->addRuntimeActivity(CUDA_LAUNCH_KERNEL, 133, 138, 1); - gpuOps->addRuntimeActivity(CUDA_MEMCPY, 210, 220, 2); - gpuOps->addRuntimeActivity(CUDA_LAUNCH_KERNEL, 230, 245, 3); - gpuOps->addKernelActivity(150, 170, 1); - gpuOps->addMemcpyActivity(240, 250, 2); - gpuOps->addKernelActivity(260, 320, 3); - cuptiActivities_.activityBuffer = std::move(gpuOps); - - // Have the profiler process them - auto logger = std::make_unique(*cfg_); - profiler.processTrace(*logger); - - // Profiler can be reset at this point - logger owns the activities - profiler_->reset(); - - // Wrapper that allows iterating over the activities - ActivityTrace trace(std::move(logger), loggerFactory); - EXPECT_EQ(trace.activities()->size(), 9); - std::map activityCounts; - std::map resourceIds; - for (auto& activity : *trace.activities()) { - activityCounts[activity->name()]++; - resourceIds[activity->resourceId()]++; - } - for (const auto& p : activityCounts) { - LOG(INFO) << p.first << ": " << p.second; - } - EXPECT_EQ(activityCounts["op1"], 1); - EXPECT_EQ(activityCounts["op2"], 1); - EXPECT_EQ(activityCounts["op3"], 1); - EXPECT_EQ(activityCounts["cudaLaunchKernel"], 2); - EXPECT_EQ(activityCounts["cudaMemcpy"], 1); - EXPECT_EQ(activityCounts["kernel"], 2); - EXPECT_EQ(activityCounts["Memcpy HtoD (Pinned -> Device)"], 1); - - auto sysTid = systemThreadId(); - // Ops and runtime events are on thread sysTid - EXPECT_EQ(resourceIds[sysTid], 6); - // Kernels are on stream 1, memcpy on stream 2 - EXPECT_EQ(resourceIds[1], 2); - EXPECT_EQ(resourceIds[2], 1); - -#ifdef __linux__ - char filename[] = "/tmp/libkineto_testXXXXXX.json"; - mkstemps(filename, 5); - trace.save(filename); - // Check that the expected file was written and that it has some content - int fd = open(filename, O_RDONLY); - if (!fd) { - perror(filename); - } - EXPECT_TRUE(fd); - // Should expect at least 100 bytes - struct stat buf{}; - fstat(fd, &buf); - EXPECT_GT(buf.st_size, 100); -#endif -} - -TEST_F(CuptiActivityProfilerTest, GpuUserAnnotationTest) { - // Verbose logging is useful for debugging - std::vector log_modules( - {"CuptiActivityProfiler.cpp"}); - SET_LOG_VERBOSITY_LEVEL(2, log_modules); - - // Start and stop profiling - CuptiActivityProfiler profiler(cuptiActivities_, /*cpu only*/ false); - int64_t start_time_us = 100; - int64_t duration_us = 300; - auto start_time = time_point(microseconds(start_time_us)); - profiler.configure(*cfg_, start_time); - profiler.startTrace(start_time); - profiler.stopTrace(start_time + microseconds(duration_us)); - - int64_t kernelLaunchTime = 120; - profiler.recordThreadInfo(); - - // set up CPU event - auto cpuOps = std::make_unique( - start_time_us, start_time_us + duration_us); - cpuOps->addOp("annotation", kernelLaunchTime, kernelLaunchTime + 10, 1); - profiler.transferCpuTrace(std::move(cpuOps)); - - // set up a couple of GPU events and correlate with above CPU event. - // CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1 is used for user annotations. - auto gpuOps = std::make_unique(); - gpuOps->addCorrelationActivity(1, CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1, 1); - gpuOps->addKernelActivity(kernelLaunchTime + 5, kernelLaunchTime + 10, 1); - gpuOps->addCorrelationActivity(1, CUPTI_EXTERNAL_CORRELATION_KIND_CUSTOM1, 1); - gpuOps->addKernelActivity(kernelLaunchTime + 15, kernelLaunchTime + 25, 1); - cuptiActivities_.activityBuffer = std::move(gpuOps); - - // process trace - auto logger = std::make_unique(*cfg_); - profiler.processTrace(*logger); - - ActivityTrace trace(std::move(logger), loggerFactory); - std::map counts; - for (auto& activity : *trace.activities()) { - counts[activity->name()]++; - } - - // We should now have an additional annotation activity created - // on the GPU timeline. - EXPECT_EQ(counts["annotation"], 2); - EXPECT_EQ(counts["kernel"], 2); - - auto& annotation = trace.activities()->at(0); - auto& kernel1 = trace.activities()->at(1); - auto& kernel2 = trace.activities()->at(2); - auto& gpu_annotation = trace.activities()->at(3); - EXPECT_EQ(gpu_annotation->type(), ActivityType::GPU_USER_ANNOTATION); - EXPECT_EQ(gpu_annotation->timestamp(), kernel1->timestamp()); - EXPECT_EQ( - gpu_annotation->duration(), - kernel2->timestamp() + kernel2->duration() - kernel1->timestamp()); - EXPECT_EQ(gpu_annotation->deviceId(), kernel1->deviceId()); - EXPECT_EQ(gpu_annotation->resourceId(), kernel1->resourceId()); - EXPECT_EQ(gpu_annotation->correlationId(), annotation->correlationId()); - EXPECT_EQ(gpu_annotation->name(), annotation->name()); -} - -TEST_F(CuptiActivityProfilerTest, SubActivityProfilers) { - using ::testing::Return; - using ::testing::ByMove; - - // Verbose logging is useful for debugging - std::vector log_modules( - {"CuptiActivityProfiler.cpp"}); - SET_LOG_VERBOSITY_LEVEL(2, log_modules); - - // Setup example events to test - GenericTraceActivity ev{defaultTraceSpan(), ActivityType::GLOW_RUNTIME, ""}; - ev.device = 1; - ev.resource = 0; - - int64_t start_time_us = 100; - int64_t duration_us = 1000; - auto start_time = time_point(microseconds(start_time_us)); - - std::vector test_activities{3, ev}; - test_activities[0].startTime = start_time_us; - test_activities[0].endTime = start_time_us + 5000; - test_activities[0].activityName = "SubGraph A execution"; - test_activities[1].startTime = start_time_us; - test_activities[1].endTime = start_time_us + 2000; - test_activities[1].activityName = "Operator foo"; - test_activities[2].startTime = start_time_us + 2500; - test_activities[2].endTime = start_time_us + 2900; - test_activities[2].activityName = "Operator bar"; - - auto mock_activity_profiler = - std::make_unique(test_activities); - - MockCuptiActivities activities; - CuptiActivityProfiler profiler(activities, /*cpu only*/ true); - profiler.addChildActivityProfiler( - std::move(mock_activity_profiler)); - - profiler.configure(*cfg_, start_time); - profiler.startTrace(start_time); - EXPECT_TRUE(profiler.isActive()); - - profiler.stopTrace(start_time + microseconds(duration_us)); - EXPECT_TRUE(profiler.isActive()); - - char filename[] = "/tmp/libkineto_testXXXXXX.json"; - mkstemps(filename, 5); - LOG(INFO) << "Logging to tmp file " << filename; - - // process trace - auto logger = std::make_unique(*cfg_); - profiler.processTrace(*logger); - profiler.setLogger(logger.get()); - - ActivityTrace trace(std::move(logger), loggerFactory); - trace.save(filename); - const auto& traced_activites = trace.activities(); - - // Test we have all the events - EXPECT_EQ(traced_activites->size(), test_activities.size()); - - // Check that the expected file was written and that it has some content - int fd = open(filename, O_RDONLY); - if (!fd) { - perror(filename); - } - EXPECT_TRUE(fd); - - // Should expect at least 100 bytes - struct stat buf{}; - fstat(fd, &buf); - EXPECT_GT(buf.st_size, 100); -} - -TEST_F(CuptiActivityProfilerTest, BufferSizeLimitTestWarmup) { - CuptiActivityProfiler profiler(cuptiActivities_, /*cpu only*/ false); - - auto now = system_clock::now(); - auto startTime = now + seconds(10); - - int maxBufferSizeMB = 3; - - auto startTimeEpoch = std::to_string(duration_cast(startTime.time_since_epoch()).count()); - std::string maxBufferSizeMBStr = std::to_string(maxBufferSizeMB); - cfg_->handleOption("ACTIVITIES_MAX_GPU_BUFFER_SIZE_MB", maxBufferSizeMBStr); - cfg_->handleOption("PROFILE_START_TIME", startTimeEpoch); - - - EXPECT_FALSE(profiler.isActive()); - profiler.configure(*cfg_, now); - EXPECT_TRUE(profiler.isActive()); - - for (size_t i = 0; i < maxBufferSizeMB; i++) { - uint8_t* buf; - size_t gpuBufferSize; - size_t maxNumRecords; - cuptiActivities_.bufferRequestedOverride(&buf, &gpuBufferSize, &maxNumRecords); - } - - // fast forward to startTime and profiler is now running - now = startTime; - - profiler.performRunLoopStep(now, now); - - auto next = now + milliseconds(1000); - profiler.performRunLoopStep(next, next); - profiler.performRunLoopStep(next, next); - profiler.performRunLoopStep(next, next); - - EXPECT_FALSE(profiler.isActive()); -} diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiCallbackApiTest.cpp b/plugins/tensorboard-plugins/libkineto/test/CuptiCallbackApiTest.cpp deleted file mode 100644 index 253b696da54d1919e9c0076c5691a11e35345686..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/CuptiCallbackApiTest.cpp +++ /dev/null @@ -1,239 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "src/Logger.h" -#include "src/CuptiCallbackApi.h" - -#include -#include -#include -#include - -using namespace std::chrono; -using namespace KINETO_NAMESPACE; -using namespace libkineto; - -const size_t some_data = 42; - -std::atomic simple_cb_calls = 0; - -void simple_cb( - CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - const CUpti_CallbackData* cbInfo) { - - // simple arg check - EXPECT_EQ(domain, CUPTI_CB_DOMAIN_RUNTIME_API); - EXPECT_EQ(cbid, CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000); - EXPECT_EQ(*reinterpret_cast(cbInfo), some_data); - - simple_cb_calls++; -} - -void atomic_cb( - CUpti_CallbackDomain /*domain*/, - CUpti_CallbackId /*cbid*/, - const CUpti_CallbackData* /*cbInfo)*/) { - // do some atomics in a loop - for (int i = 0; i < 1000; i++) { - // would have used release consistency but this is fine - simple_cb_calls++; - } -} - -void empty_cb( - CUpti_CallbackDomain /*domain*/, - CUpti_CallbackId /*cbid*/, - const CUpti_CallbackData* /*cbInfo*/) { -} - -TEST(CuptiCallbackApiTest, SimpleTest) { - auto& api = CuptiCallbackApi::singleton(); - - auto addSimpleCallback = [&]() -> bool { - bool ret = api.registerCallback( - CUPTI_CB_DOMAIN_RUNTIME_API, - CuptiCallbackApi::CUDA_LAUNCH_KERNEL, - &simple_cb - ); - return ret; - }; - EXPECT_TRUE(addSimpleCallback()) << "Failed to add callback"; - - // duplicate add should be okay - EXPECT_TRUE(addSimpleCallback()) << "Failed to re-add callback"; - - simple_cb_calls = 0; - - // simulate callback - api.__callback_switchboard( - CUPTI_CB_DOMAIN_RUNTIME_API, - CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000, - reinterpret_cast(&some_data)); - - EXPECT_EQ(simple_cb_calls, 1); - - bool ret = api.deleteCallback( - CUPTI_CB_DOMAIN_RUNTIME_API, - CuptiCallbackApi::CUDA_LAUNCH_KERNEL, - &simple_cb - ); - - EXPECT_TRUE(ret) << "Failed to remove callback"; - - ret = api.deleteCallback( - CUPTI_CB_DOMAIN_RUNTIME_API, - CuptiCallbackApi::CUDA_LAUNCH_KERNEL, - &atomic_cb - ); - - EXPECT_FALSE(ret) << "oops! deleted a callback that was never added"; -} - -TEST(CuptiCallbackApiTest, AllCallbacks) { - auto& api = CuptiCallbackApi::singleton(); - - auto testCallback = [&]( - CUpti_CallbackDomain domain, - CUpti_CallbackId cbid, - CuptiCallbackApi::CuptiCallBackID kineto_cbid) -> bool { - - bool ret = api.registerCallback(domain, kineto_cbid, atomic_cb); - EXPECT_TRUE(ret) << "Failed to add callback"; - - if (!ret) { - return false; - } - - simple_cb_calls = 0; - api.__callback_switchboard(domain, cbid, nullptr); - EXPECT_EQ(simple_cb_calls, 1000); - ret = simple_cb_calls == 1000; - - EXPECT_TRUE(api.deleteCallback(domain, kineto_cbid, atomic_cb)); - - return ret; - }; - - EXPECT_TRUE( - testCallback( - CUPTI_CB_DOMAIN_RESOURCE, - CUPTI_CBID_RESOURCE_CONTEXT_CREATED, - CuptiCallbackApi::RESOURCE_CONTEXT_CREATED)) - << "Failed to run callback for RESOURCE_CONTEXT_CREATED"; - - EXPECT_TRUE( - testCallback( - CUPTI_CB_DOMAIN_RESOURCE, - CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING, - CuptiCallbackApi::RESOURCE_CONTEXT_DESTROYED)) - << "Failed to run callback for RESOURCE_CONTEXT_DESTROYED"; - - EXPECT_TRUE( - testCallback( - CUPTI_CB_DOMAIN_RUNTIME_API, - CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000, - CuptiCallbackApi::CUDA_LAUNCH_KERNEL)) - << "Failed to run callback for CUDA_LAUNCH_KERNEL"; - -} - -TEST(CuptiCallbackApiTest, ContentionTest) { - auto& api = CuptiCallbackApi::singleton(); - const CUpti_CallbackDomain domain = CUPTI_CB_DOMAIN_RUNTIME_API; - const CUpti_CallbackId cbid = CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000; - const CuptiCallbackApi::CuptiCallBackID kineto_cbid = - CuptiCallbackApi::CUDA_LAUNCH_KERNEL; - - bool ret = api.registerCallback(domain, kineto_cbid, empty_cb); - EXPECT_TRUE(ret) << "Failed to add callback"; - - const int iters = 10000; - const int num_readers = 8; - - simple_cb_calls = 0; - - // simulate callbacks being executed on multiple threads in parallel - // during this interval add a new atomic_callback. - // this test ensured mutual exclusion is working fine - auto read_fn = [&](int tid){ - auto start_ts = high_resolution_clock::now(); - for (int i = 0; i < iters; i++) { - api.__callback_switchboard(domain, cbid, nullptr); - } - auto runtime_ms = duration_cast( - high_resolution_clock::now() - start_ts); - LOG(INFO) << "th " << tid << " done in " << runtime_ms.count() << " ms"; - }; - - - std::vector read_ths; - for (int i = 0; i< num_readers; i++) { - read_ths.emplace_back(read_fn, i); - } - - ret = api.registerCallback(domain, kineto_cbid, atomic_cb); - EXPECT_TRUE(ret) << "Failed to add callback"; - - for (auto& t : read_ths) { - t.join(); - } - - //EXPECT_GT(simple_cb_calls, 0) - // << "Atomic callback should have been called at least once."; - - api.deleteCallback(domain, kineto_cbid, empty_cb); - api.deleteCallback(domain, kineto_cbid, atomic_cb); -} - -TEST(CuptiCallbackApiTest, Bechmark) { - - constexpr int iters = 1000; - // atomic bench a number of times to get a baseline - - const CUpti_CallbackDomain domain = CUPTI_CB_DOMAIN_RUNTIME_API; - const CUpti_CallbackId cbid = CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000; - const CuptiCallbackApi::CuptiCallBackID kineto_cbid = - CuptiCallbackApi::CUDA_LAUNCH_KERNEL; - - LOG(INFO) << "Iteration count = " << iters; - - const bool use_empty = true; - auto cbfn = use_empty ? &empty_cb : &atomic_cb; - - // warmup - for (int i = 0; i < 50; i++) { - (*cbfn)(domain, cbid, nullptr); - } - - auto start_ts = high_resolution_clock::now(); - for (int i = 0; i < iters; i++) { - (*cbfn)(domain, cbid, nullptr); - } - auto delta_baseline_ns = duration_cast( - high_resolution_clock::now() - start_ts); - LOG(INFO) << "Baseline runtime = " << delta_baseline_ns.count() << " ns"; - - - auto& api = CuptiCallbackApi::singleton(); - bool ret = api.registerCallback(domain, kineto_cbid, cbfn); - EXPECT_TRUE(ret) << "Failed to add callback"; - - // warmup - for (int i = 0; i < 50; i++) { - api.__callback_switchboard(domain, cbid, nullptr); - } - - start_ts = high_resolution_clock::now(); - for (int i = 0; i < iters; i++) { - api.__callback_switchboard(domain, cbid, nullptr); - } - - auto delta_callback_ns = duration_cast( - high_resolution_clock::now() - start_ts); - LOG(INFO) << "Callback runtime = " << delta_callback_ns.count() << " ns"; - - LOG(INFO) << "Callback runtime per iteration = " << - (delta_callback_ns.count() - delta_baseline_ns.count()) / (double) iters - << " ns"; - -} diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiProfilerApiTest.cu b/plugins/tensorboard-plugins/libkineto/test/CuptiProfilerApiTest.cu deleted file mode 100644 index 54ad51b0a1fc9a6a54585d1cad4674943c874b98..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/CuptiProfilerApiTest.cu +++ /dev/null @@ -1,353 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include -#include -#include - -#include - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "src/Logger.h" -#include "src/CuptiRangeProfilerApi.h" - -#define DRIVER_API_CALL(apiFuncCall) \ - do { \ - CUresult _status = apiFuncCall; \ - if (_status != CUDA_SUCCESS) { \ - LOG(ERROR) << "Failed invoking CUDA driver function " \ - << #apiFuncCall << " status = " \ - << _status; \ - exit(-1); \ - } \ - } while (0) - -#define EXPECT(expr)\ - if (!(expr)) {\ - }; - -using namespace KINETO_NAMESPACE; - -static int numRanges = 1; - -using Type = double; - -// Device code -__global__ void VecAdd(const Type* A, const Type* B, Type* C, int N) { - int i = blockDim.x * blockIdx.x + threadIdx.x; - if (i < N) { - C[i] = A[i] + B[i]; - } -} - -// Device code -__global__ void VecSub(const Type* A, const Type* B, Type* C, int N) { - int i = blockDim.x * blockIdx.x + threadIdx.x; - if (i < N) { - C[i] = A[i] - B[i]; - } -} - -static void initVec(Type* vec, int n) { - for (int i = 0; i < n; i++) { - vec[i] = i; - } -} - -static void cleanUp( - Type* h_A, - Type* h_B, - Type* h_C, - Type* h_D, - Type* d_A, - Type* d_B, - Type* d_C, - Type* d_D) { - if (d_A) - cudaFree(d_A); - if (d_B) - cudaFree(d_B); - if (d_C) - cudaFree(d_C); - if (d_D) - cudaFree(d_D); - - // Free host memory - if (h_A) - free(h_A); - if (h_B) - free(h_B); - if (h_C) - free(h_C); - if (h_D) - free(h_D); -} - -/* Benchmark application used to test profiler measurements - * This simply runs two kernels vector Add and Vector Subtract - */ - -void VectorAddSubtract() { - int N = 50000; - size_t size = N * sizeof(Type); - int threadsPerBlock = 0; - int blocksPerGrid = 0; - Type *h_A, *h_B, *h_C, *h_D; - Type *d_A, *d_B, *d_C, *d_D; - int i; - Type sum, diff; - - // Allocate input vectors h_A and h_B in host memory - h_A = (Type*)malloc(size); - h_B = (Type*)malloc(size); - h_C = (Type*)malloc(size); - h_D = (Type*)malloc(size); - - // Initialize input vectors - initVec(h_A, N); - initVec(h_B, N); - memset(h_C, 0, size); - memset(h_D, 0, size); - - // Allocate vectors in device memory - cudaMalloc((void**)&d_A, size); - cudaMalloc((void**)&d_B, size); - cudaMalloc((void**)&d_C, size); - cudaMalloc((void**)&d_D, size); - - // Copy vectors from host memory to device memory - cudaMemcpy(d_A, h_A, size, cudaMemcpyHostToDevice); - cudaMemcpy(d_B, h_B, size, cudaMemcpyHostToDevice); - - // Invoke kernel - threadsPerBlock = 256; - blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock; - LOG(INFO) << fmt::format( - "Launching kernel: blocks {}, thread/block {}", - blocksPerGrid, - threadsPerBlock); - - VecAdd<<>>(d_A, d_B, d_C, N); - - VecSub<<>>(d_A, d_B, d_D, N); - - // Copy result from device memory to host memory - // h_C contains the result in host memory - cudaMemcpy(h_C, d_C, size, cudaMemcpyDeviceToHost); - cudaMemcpy(h_D, d_D, size, cudaMemcpyDeviceToHost); - - // Verify result - for (i = 0; i < N; ++i) { - sum = h_A[i] + h_B[i]; - diff = h_A[i] - h_B[i]; - if (h_C[i] != sum || h_D[i] != diff) { - LOG(ERROR) << "Result verification failed"; - break; - } - } - - cleanUp(h_A, h_B, h_C, h_D, d_A, d_B, d_C, d_D); -} - -#if HAS_CUPTI_RANGE_PROFILER -bool runTestWithAutoRange( - int deviceNum, - const std::vector& metricNames, - CUcontext cuContext, - bool async) { - - // create a CUPTI range based profiling profiler - // this configures the counter data as well - CuptiRBProfilerSession profiler( - metricNames, deviceNum, 2, 1, async ? nullptr : cuContext); - - CUpti_ProfilerRange profilerRange = CUPTI_AutoRange; - CUpti_ProfilerReplayMode profilerReplayMode = CUPTI_KernelReplay; - - if (async) { - profiler.asyncStartAndEnable(profilerRange, profilerReplayMode); - } else { - profiler.start(profilerRange, profilerReplayMode); - profiler.enable(); - } - - VectorAddSubtract(); - - if (!async) { - profiler.disable(); - // stop profiler - profiler.stop(); - } else { - profiler.asyncDisableAndStop(); - } - - auto result = profiler.evaluateMetrics(true); - - // check results - EXPECT_EQ(result.metricNames.size(), 3); - EXPECT_EQ(result.rangeVals.size(), 2); - - for (const auto& measurement : result.rangeVals) { - EXPECT_EQ(measurement.values.size(), 3); - - if (measurement.values.size() == 3) { - // smsp__warps_launched.avg - EXPECT_NE(measurement.values[0], 0); - // smsp__sass_thread_inst_executed_op_dadd_pred_on.sum - // each kernel has 50000 dadd ops - EXPECT_EQ(measurement.values[1], 50000); - // sm__inst_executed_pipe_tensor.sum - //EXPECT_EQ(measurement.values[2], 0); - } - } - return true; -} - -bool runTestWithUserRange( - int deviceNum, - const std::vector& metricNames, - CUcontext cuContext, - bool async = false) { - - // create a CUPTI range based profiling profiler - // this configures the counter data as well - CuptiRBProfilerSession profiler( - metricNames, deviceNum, numRanges, 1, async ? nullptr : cuContext); - - CUpti_ProfilerRange profilerRange = CUPTI_UserRange; - CUpti_ProfilerReplayMode profilerReplayMode = CUPTI_UserReplay; - - if (async) { - profiler.asyncStartAndEnable(profilerRange, profilerReplayMode); - { VectorAddSubtract(); } - profiler.disableAndStop(); - } else { - profiler.start(profilerRange, profilerReplayMode); - - /* User takes the resposiblity of replaying the kernel launches */ - bool replay = true; - do { - profiler.beginPass(); - { - profiler.enable(); - - std::string rangeName = "vecAddSub"; - profiler.pushRange(rangeName); - - { VectorAddSubtract(); } - - profiler.popRange(); - profiler.disable(); - } - LOG(INFO) << "Replay starting."; - replay = profiler.endPass(); - - } while (!replay); - - // stop profiler - profiler.stop(); - } - VectorAddSubtract(); - auto result = profiler.evaluateMetrics(true); - - // check results - EXPECT_EQ(result.metricNames.size(), 3); - EXPECT_EQ(result.rangeVals.size(), 1); - - if (result.rangeVals.size() > 0) { - const auto& measurement = result.rangeVals[0]; - EXPECT_EQ(measurement.values.size(), 3); - - if (measurement.values.size() == 3) { - // smsp__warps_launched.avg - EXPECT_NE(measurement.values[0], 0); - // smsp__sass_thread_inst_executed_op_dadd_pred_on.sum - // in async mode multiple passes are not supported yet - if (!async) { - EXPECT_EQ(measurement.values[1], 100000); - } - // sm__inst_executed_pipe_tensor.sum - //EXPECT_EQ(measurement.values[2], 0); - } - } - return true; -} -#endif // HAS_CUPTI_RANGE_PROFILER - -int main(int argc, char* argv[]) { - - CUdevice cuDevice; - - int deviceCount, deviceNum; - int computeCapabilityMajor = 0, computeCapabilityMinor = 0; - - printf("Usage: %s [device_num]\n", argv[0]); - - DRIVER_API_CALL(cuInit(0)); - DRIVER_API_CALL(cuDeviceGetCount(&deviceCount)); - - if (deviceCount == 0) { - LOG(ERROR) << "There is no device supporting CUDA."; - return -2; - } - - if (argc > 1) - deviceNum = atoi(argv[1]); - else - deviceNum = 0; - LOG(INFO) << "CUDA Device Number: " << deviceNum; - - DRIVER_API_CALL(cuDeviceGet(&cuDevice, deviceNum)); - DRIVER_API_CALL(cuDeviceGetAttribute( - &computeCapabilityMajor, - CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, - cuDevice)); - DRIVER_API_CALL(cuDeviceGetAttribute( - &computeCapabilityMinor, - CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, - cuDevice)); - - LOG(INFO) << "Compute Cabapbility = " - << fmt::format("{},{}",computeCapabilityMajor, computeCapabilityMinor); - - if (computeCapabilityMajor < 7) { - LOG(ERROR) << "CUPTI Profiler is not supported with compute capability < 7.0"; - return -2; - } - - CuptiRBProfilerSession::staticInit(); - - // metrics to profile - std::vector metricNames = { - "smsp__warps_launched.avg", - "smsp__sass_thread_inst_executed_op_dadd_pred_on.sum", - "sm__inst_executed_pipe_tensor.sum", - }; - - CUcontext cuContext; - DRIVER_API_CALL(cuCtxCreate(&cuContext, 0, cuDevice)); - - VectorAddSubtract(); - -#if HAS_CUPTI_RANGE_PROFILER - CuptiRBProfilerSession::staticInit(); - - if (!runTestWithUserRange(deviceNum, metricNames, cuContext, false)) { - LOG(ERROR) << "Failed to profiler test benchmark in user range"; - } else if (!runTestWithAutoRange(deviceNum, metricNames, cuContext, false)) { - LOG(ERROR) << "Failed to profiler test benchmark in auto range"; - } else if (!runTestWithUserRange(deviceNum, metricNames, cuContext, true)) { - LOG(ERROR) << "Failed to profiler test benchmark in user range async"; - } else if (!runTestWithAutoRange(deviceNum, metricNames, cuContext, true)) { - LOG(ERROR) << "Failed to profiler test benchmark in auto range async"; - } - - CuptiRBProfilerSession::deInitCupti(); -#else - LOG(WARNING) << "CuptiRBProfilerSession is not supported."; -#endif // HAS_CUPTI_RANGE_PROFILER - DRIVER_API_CALL(cuCtxDestroy(cuContext)); - - - return 0; -} diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerApiTest.cpp b/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerApiTest.cpp deleted file mode 100644 index 28cad722c53ee5defaa7c24cbe0d6b2cbc840a30..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerApiTest.cpp +++ /dev/null @@ -1,113 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include -#include -#include - -#include "include/libkineto.h" -#include "include/Config.h" -#include "src/CuptiRangeProfilerApi.h" - -#include "src/Logger.h" -#include "test/CuptiRangeProfilerTestUtil.h" - -using namespace KINETO_NAMESPACE; - -#if HAS_CUPTI_PROFILER - -TEST(CuptiRangeProfilerApiTest, contextTracking) { - std::vector log_modules( - {"CuptiRangeProfilerApi.cpp"}); - SET_LOG_VERBOSITY_LEVEL(1, log_modules); - - std::array data; - std::array contexts; - for (int i = 0; i < data.size(); i++) { - contexts[i] = reinterpret_cast(&data[i]); - } - - // simulate creating contexts, this calls the trackCudaContexts - // function that would otherwise be called via a callback - uint32_t dev = 0; - for (auto ctx : contexts) { - simulateCudaContextCreate(ctx, dev++); - } - - EXPECT_EQ( - CuptiRBProfilerSession::getActiveDevices(), - std::set({0, 1, 2})); - - simulateCudaContextDestroy(contexts[1], 1); - - EXPECT_EQ( - CuptiRBProfilerSession::getActiveDevices(), - std::set({0, 2})); - - simulateCudaContextDestroy(contexts[0], 0); - simulateCudaContextDestroy(contexts[2], 2); - - EXPECT_TRUE( - CuptiRBProfilerSession::getActiveDevices().empty()); -} - -TEST(CuptiRangeProfilerApiTest, asyncLaunchUserRange) { - std::vector log_modules( - {"CuptiRangeProfilerApi.cpp"}); - SET_LOG_VERBOSITY_LEVEL(1, log_modules); - - // this is bad but the pointer is never accessed - CUcontext ctx0 = reinterpret_cast(10); - simulateCudaContextCreate(ctx0, 0 /*device_id*/); - - auto session = std::make_unique(0, ctx0); - session->asyncStartAndEnable(CUPTI_UserRange, CUPTI_UserReplay); - - simulateKernelLaunch(ctx0, "hello"); - simulateKernelLaunch(ctx0, "foo"); - simulateKernelLaunch(ctx0, "bar"); - - session->asyncDisableAndStop(); - // stop happens after next kernel is run - simulateKernelLaunch(ctx0, "bar"); - simulateCudaContextDestroy(ctx0, 0 /*device_id*/); - - EXPECT_EQ(session->passes_ended, 1); - EXPECT_EQ(session->ranges_ended, 1); - EXPECT_TRUE(session->enabled); -} - -TEST(CuptiRangeProfilerApiTest, asyncLaunchAutoRange) { - std::vector log_modules( - {"CuptiRangeProfilerApi.cpp"}); - SET_LOG_VERBOSITY_LEVEL(1, log_modules); - - // this is bad but the pointer is never accessed - CUcontext ctx0 = reinterpret_cast(10); - CUcontext ctx1 = reinterpret_cast(11); - - simulateCudaContextCreate(ctx0, 0 /*device_id*/); - - auto session = std::make_unique(0, ctx0); - session->asyncStartAndEnable(CUPTI_AutoRange, CUPTI_KernelReplay); - - simulateKernelLaunch(ctx0, "hello"); - simulateKernelLaunch(ctx0, "foo"); - simulateKernelLaunch(ctx1, "kernel_on_different_device"); - simulateKernelLaunch(ctx0, "bar"); - - session->asyncDisableAndStop(); - // stop happens after next kernel is run - simulateKernelLaunch(ctx0, "bar"); - simulateCudaContextDestroy(ctx0, 0 /*device_id*/); - - EXPECT_EQ(session->passes_ended, 0); - EXPECT_EQ(session->ranges_ended, 0); - EXPECT_TRUE(session->enabled); - - EXPECT_EQ( - session->getKernelNames(), - std::vector({"hello", "foo", "bar"})) - << "Kernel names were not tracked"; -} - -#endif // HAS_CUPTI_PROFILER diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerConfigTest.cpp b/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerConfigTest.cpp deleted file mode 100644 index 3f568968238a0e376ab3bae621af00a162af0d25..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerConfigTest.cpp +++ /dev/null @@ -1,67 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "include/Config.h" -#include "src/CuptiRangeProfilerConfig.h" - -#include -#include -#include -#include - -using namespace std::chrono; -using namespace KINETO_NAMESPACE; - -class CuptiRangeProfilerConfigTest : public ::testing::Test { - protected: - void SetUp() override { - CuptiRangeProfilerConfig::registerFactory(); - } -}; - -TEST_F(CuptiRangeProfilerConfigTest, ConfigureProfiler) { - Config cfg; - std::vector metrics = { - "kineto__cuda_core_flops", - "sm__inst_executed.sum", - "l1tex__data_bank_conflicts_pipe_lsu.sum", - }; - auto metricsConfigStr = - fmt::format("CUPTI_PROFILER_METRICS = {}", fmt::join(metrics, ",")); - - EXPECT_TRUE(cfg.parse(metricsConfigStr)); - EXPECT_TRUE(cfg.parse("CUPTI_PROFILER_ENABLE_PER_KERNEL = true")); - EXPECT_TRUE(cfg.parse("CUPTI_PROFILER_MAX_RANGES = 42")); - - const CuptiRangeProfilerConfig& cupti_cfg = - CuptiRangeProfilerConfig::get(cfg); - - EXPECT_EQ(cupti_cfg.activitiesCuptiMetrics(), metrics); - EXPECT_EQ(cupti_cfg.cuptiProfilerPerKernel(), true); - EXPECT_EQ(cupti_cfg.cuptiProfilerMaxRanges(), 42); - -} - -TEST_F(CuptiRangeProfilerConfigTest, RangesDefaults) { - Config cfg, cfg_auto; - - // do not set max ranges in config, check defaults are sane - EXPECT_TRUE(cfg.parse("CUPTI_PROFILER_METRICS = kineto__cuda_core_flops")); - EXPECT_TRUE(cfg.parse("CUPTI_PROFILER_ENABLE_PER_KERNEL = false")); - - cfg.setSignalDefaults(); - - EXPECT_TRUE(cfg_auto.parse("CUPTI_PROFILER_METRICS = kineto__cuda_core_flops")); - EXPECT_TRUE(cfg_auto.parse("CUPTI_PROFILER_ENABLE_PER_KERNEL = true")); - - cfg_auto.setClientDefaults(); - - int user_ranges, auto_ranges; - - user_ranges = CuptiRangeProfilerConfig::get(cfg).cuptiProfilerMaxRanges(); - auto_ranges = CuptiRangeProfilerConfig::get(cfg_auto).cuptiProfilerMaxRanges(); - - EXPECT_GE(user_ranges, 1) << " in user range mode default to at least 1 ranges"; - EXPECT_GE(auto_ranges, 1000) << " in auto range mode default to at least 1000 ranges"; - - EXPECT_GT(auto_ranges, user_ranges); -} diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerTestUtil.h b/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerTestUtil.h deleted file mode 100644 index 861b65fd701bf69373df657ab2a22d9dba0b27df..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerTestUtil.h +++ /dev/null @@ -1,96 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include -#include - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "CuptiRangeProfilerApi.h" - -namespace KINETO_NAMESPACE { - -#if HAS_CUPTI_PROFILER - -class MockCuptiRBProfilerSession : public CuptiRBProfilerSession { - public: - MockCuptiRBProfilerSession(int deviceId, CUcontext ctx) - : CuptiRBProfilerSession(deviceId, ctx) {} - - void beginPass() override { - LOG(INFO) << " Mock CUPTI begin pass"; - passes_started++; - } - - bool endPass() override { - passes_ended++; - return true; - } - - void flushCounterData() override {} - - void pushRange(const std::string& rangeName) override { - LOG(INFO) << " Mock CUPTI pushrange ( " << rangeName << " )"; - ranges_started++; - } - - void popRange() override { - LOG(INFO) << " Mock CUPTI poprange"; - ranges_ended++; - } - - void stop() override { - runChecks(); - } - - void enable() override { - enabled = true; - } - void disable() override {} - - CuptiProfilerResult evaluateMetrics(bool /*verbose*/) override { - return result; - } - -protected: - void startInternal( - CUpti_ProfilerRange profilerRange, - CUpti_ProfilerReplayMode profilerReplayMode) override { - curRange_ = profilerRange; - curReplay_ = profilerReplayMode; - } - -private: - void runChecks() { - EXPECT_EQ(passes_started, passes_ended); - EXPECT_EQ(ranges_started, ranges_ended); - } - - public: - int passes_started = 0; - int passes_ended = 0; - int ranges_started = 0; - int ranges_ended = 0; - bool enabled = false; - - CuptiProfilerResult result; - -}; - -inline void simulateCudaContextCreate(CUcontext context, uint32_t dev) { - testing::trackCudaCtx( - context, dev, CUPTI_CBID_RESOURCE_CONTEXT_CREATED); -} - -inline void simulateCudaContextDestroy(CUcontext context, uint32_t dev) { - testing::trackCudaCtx( - context, dev, CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING); -} - -inline void simulateKernelLaunch( - CUcontext context, const std::string& kernelName) { - testing::trackCudaKernelLaunch(context, kernelName.c_str()); -} - -#endif // HAS_CUPTI_PROFILER - -} // namespace KINETO_NAMESPACE diff --git a/plugins/tensorboard-plugins/libkineto/test/CuptiStringsTest.cpp b/plugins/tensorboard-plugins/libkineto/test/CuptiStringsTest.cpp deleted file mode 100644 index 405f9404a49a5bf8b7433930b0ad2fe898ea2d89..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/CuptiStringsTest.cpp +++ /dev/null @@ -1,29 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include - -#include "src/cupti_strings.h" - -using namespace KINETO_NAMESPACE; - -TEST(CuptiStringsTest, Valid) { - ASSERT_STREQ( - runtimeCbidName(CUPTI_RUNTIME_TRACE_CBID_INVALID), "INVALID"); - ASSERT_STREQ( - runtimeCbidName(CUPTI_RUNTIME_TRACE_CBID_cudaDriverGetVersion_v3020), - "cudaDriverGetVersion"); - ASSERT_STREQ(runtimeCbidName - (CUPTI_RUNTIME_TRACE_CBID_cudaDeviceSynchronize_v3020), - "cudaDeviceSynchronize"); - ASSERT_STREQ( - runtimeCbidName(CUPTI_RUNTIME_TRACE_CBID_cudaStreamSetAttribute_ptsz_v11000), - "cudaStreamSetAttribute_ptsz"); -} - -TEST(CuptiStringsTest, Invalid) { - ASSERT_STREQ(runtimeCbidName(-1), "INVALID"); - // We can't actually use CUPTI_RUNTIME_TRACE_CBID_SIZE here until we - // auto-generate the string table, since it may have more entries than - // the enum in the version used to compile. - ASSERT_STREQ(runtimeCbidName(1000), "INVALID"); -} diff --git a/plugins/tensorboard-plugins/libkineto/test/EventProfilerTest.cpp b/plugins/tensorboard-plugins/libkineto/test/EventProfilerTest.cpp deleted file mode 100644 index cb36c826a7f32b2fe6732e73eae3b6a006b0cd3d..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/EventProfilerTest.cpp +++ /dev/null @@ -1,578 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "src/EventProfiler.h" - -#include -#include -#include - -using namespace std::chrono; -using namespace KINETO_NAMESPACE; - -TEST(PercentileTest, Create) { - PercentileList pct = {{10, SampleValue(0)}, - {49, SampleValue(0)}, - {50, SampleValue(0)}, - {90, SampleValue(0)}}; - - percentiles({0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, pct); - EXPECT_EQ(pct[0].second.getInt(), 10); - EXPECT_EQ(pct[1].second.getInt(), 50); - EXPECT_EQ(pct[2].second.getInt(), 50); - EXPECT_EQ(pct[3].second.getInt(), 90); - - percentiles({80, 10, 20, 70, 60, 40, 90, 30, 50, 0, 100}, pct); - EXPECT_EQ(pct[0].second.getInt(), 10); - EXPECT_EQ(pct[1].second.getInt(), 50); - EXPECT_EQ(pct[2].second.getInt(), 50); - EXPECT_EQ(pct[3].second.getInt(), 90); - - percentiles({80}, pct); - EXPECT_EQ(pct[0].second.getInt(), 80); - EXPECT_EQ(pct[1].second.getInt(), 80); - EXPECT_EQ(pct[2].second.getInt(), 80); - EXPECT_EQ(pct[3].second.getInt(), 80); - - percentiles({80, 50}, pct); - EXPECT_EQ(pct[0].second.getInt(), 50); - EXPECT_EQ(pct[1].second.getInt(), 50); - EXPECT_EQ(pct[2].second.getInt(), 80); - EXPECT_EQ(pct[3].second.getInt(), 80); -} - -TEST(PercentileTest, Normalize) { - PercentileList pct = { - {10, SampleValue(10)}, {50, SampleValue(100.0)}, {90, SampleValue(2000)}}; - - normalize(pct, 2.5); - - EXPECT_EQ(pct[0].second.getInt(), 25); - EXPECT_EQ((int)pct[1].second.getDouble(), 250); - EXPECT_EQ(pct[2].second.getInt(), 5000); -} - -TEST(EventTest, SumSamples) { - Event ev; - ev.instanceCount = 4; - auto t = system_clock::now(); - ev.addSample(t, {1, 2, 3, 4}); - ev.addSample(t, {10, 20, 30, 40}); - ev.addSample(t, {100, 200, 300, 400}); - - EXPECT_EQ(ev.sumInstance(0, {0, 0, 3}), 1); - EXPECT_EQ(ev.sumInstance(0, {0, 1, 3}), 10); - EXPECT_EQ(ev.sumInstance(0, {0, 2, 3}), 100); - - EXPECT_EQ(ev.sumInstance(0, {0, 0, 1}), 111); - - EXPECT_EQ(ev.sumInstance(3, {0, 0, 1}), 444); - - // Non-zero offset - EXPECT_EQ(ev.sumInstance(0, {1, 0, 2}), 10); - EXPECT_EQ(ev.sumInstance(0, {1, 1, 2}), 100); - EXPECT_EQ(ev.sumInstance(0, {1, 0, 1}), 110); - - ev.addSample(t, {1000, 2000, 3000, 4000}); - - EXPECT_EQ(ev.sumInstance(0, {1, 0, 3}), 10); - EXPECT_EQ(ev.sumInstance(0, {1, 1, 3}), 100); - EXPECT_EQ(ev.sumInstance(0, {2, 1, 2}), 1000); - EXPECT_EQ(ev.sumInstance(0, {2, 0, 1}), 1100); - - EXPECT_EQ(ev.sumAll({0, 0, 4}), 10); - EXPECT_EQ(ev.sumAll({1, 0, 3}), 100); - EXPECT_EQ(ev.sumAll({2, 1, 2}), 10000); - EXPECT_EQ(ev.sumAll({0, 1, 2}), 11000); - EXPECT_EQ(ev.sumAll({0, 0, 1}), 11110); -} - -TEST(EventTest, Percentiles) { - Event ev; - ev.instanceCount = 4; - auto t = system_clock::now(); - ev.addSample(t, {3, 2, 1, 4}); - ev.addSample(t, {30, 20, 10, 40}); - ev.addSample(t, {300, 200, 100, 400}); - - PercentileList pct = { - {10, SampleValue(0)}, {50, SampleValue(0)}, {90, SampleValue(0)}}; - - ev.percentiles(pct, {0, 0, 3}); - EXPECT_EQ(pct[0].second.getInt(), 1); - EXPECT_EQ(pct[1].second.getInt(), 3); - EXPECT_EQ(pct[2].second.getInt(), 4); - - ev.percentiles(pct, {0, 0, 1}); - EXPECT_EQ(pct[0].second.getInt(), 111); - EXPECT_EQ(pct[1].second.getInt(), 333); - EXPECT_EQ(pct[2].second.getInt(), 444); -} - -class MockCuptiMetrics : public CuptiMetricApi { - public: - MockCuptiMetrics() : CuptiMetricApi(0) {} - MOCK_METHOD1(idFromName, CUpti_MetricID(const std::string& name)); - MOCK_METHOD1( - events, - std::map(CUpti_MetricID metric_id)); - MOCK_METHOD1(valueKind, CUpti_MetricValueKind(CUpti_MetricID metric)); - MOCK_METHOD1( - evaluationMode, - CUpti_MetricEvaluationMode(CUpti_MetricID metric)); - MOCK_METHOD5( - calculate, - SampleValue( - CUpti_MetricID metric, - CUpti_MetricValueKind kind, - std::vector& events, - std::vector& values, - int64_t duration)); -}; - -TEST(MetricTest, Calculate) { - using ::testing::Return; - MockCuptiMetrics metrics; - - // The events used for the ipc metrics: instructions and cycles - // Pretend we have 2 SMs and 2 samples of each event - Event instr("instructions"); - instr.instanceCount = 2; - auto t = system_clock::now(); - instr.addSample(t, {100, 200}); - instr.addSample(t, {300, 400}); - - Event cycles("cycles"); - cycles.instanceCount = 2; - cycles.addSample(t, {1000, 1200}); - cycles.addSample(t, {1300, 1300}); - - // 2 & 3 are the event ids we specified in the metric - std::map events; - events[2] = std::move(instr); - events[3] = std::move(cycles); - - // Define an ipc metric - EXPECT_CALL(metrics, valueKind(1)) - .Times(1) - .WillOnce(Return(CUPTI_METRIC_VALUE_KIND_DOUBLE)); - Metric m( - "ipc", 1, {2, 3}, CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE, metrics); - - // Calculate metric for first sample - // Since evaluation mode is CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE, - // Cupti API will be called three times: once for each SM (2) and once - // to get the total across SMs. - std::vector ids = {2, 3}; - std::vector vals = {100, 1000}; - EXPECT_CALL( - metrics, calculate(1, CUPTI_METRIC_VALUE_KIND_DOUBLE, ids, vals, 1000)) - .Times(1) - .WillOnce(Return(SampleValue(0.1))); - vals = {200, 1200}; - EXPECT_CALL( - metrics, calculate(1, CUPTI_METRIC_VALUE_KIND_DOUBLE, ids, vals, 1000)) - .Times(1) - .WillOnce(Return(SampleValue(0.17))); - vals = {300, 2200}; - EXPECT_CALL( - metrics, calculate(1, CUPTI_METRIC_VALUE_KIND_DOUBLE, ids, vals, 1000)) - .Times(1) - .WillOnce(Return(SampleValue(0.14))); - auto v = m.calculate(events, nanoseconds(1000), {0, 0, 2}); - - EXPECT_EQ(v.perInstance.size(), 2); - EXPECT_EQ(v.perInstance[0].getDouble(), 0.1); - EXPECT_EQ(v.perInstance[1].getDouble(), 0.17); - EXPECT_EQ(v.total.getDouble(), 0.14); - - // Calculate second sample. - // Change evaluation mode to CUPTI_METRIC_EVALUATION_MODE_AGGREGATE. - // Now we should get only one call to the Cupti API for the total. - EXPECT_CALL(metrics, valueKind(1)) - .Times(1) - .WillOnce(Return(CUPTI_METRIC_VALUE_KIND_DOUBLE)); - Metric m2("ipc", 1, {2, 3}, CUPTI_METRIC_EVALUATION_MODE_AGGREGATE, metrics); - vals = {700, 2600}; - EXPECT_CALL( - metrics, calculate(1, CUPTI_METRIC_VALUE_KIND_DOUBLE, ids, vals, 1000)) - .Times(1) - .WillOnce(Return(SampleValue(0.27))); - v = m2.calculate(events, nanoseconds(1000), {0, 1, 2}); - - EXPECT_EQ(v.perInstance.size(), 1); - EXPECT_EQ(v.perInstance[0].getDouble(), 0.27); - EXPECT_EQ(v.total.getDouble(), 0.27); -} - -class MockCuptiEvents : public CuptiEventApi { - public: - MOCK_METHOD1( - createGroupSets, - CUpti_EventGroupSets*(std::vector& ids)); - MOCK_METHOD1(destroyGroupSets, void(CUpti_EventGroupSets* sets)); - MOCK_METHOD0(setContinuousMode, bool()); - MOCK_METHOD1(enablePerInstance, void(CUpti_EventGroup eventGroup)); - MOCK_METHOD1(instanceCount, uint32_t(CUpti_EventGroup eventGroup)); - MOCK_METHOD1(enableGroupSet, void(CUpti_EventGroupSet& set)); - MOCK_METHOD1(disableGroupSet, void(CUpti_EventGroupSet& set)); - MOCK_METHOD3( - readEvent, - void(CUpti_EventGroup g, CUpti_EventID id, std::vector& vals)); - MOCK_METHOD1(eventsInGroup, std::vector(CUpti_EventGroup g)); - MOCK_METHOD1(eventId, CUpti_EventID(const std::string& name)); -}; - -TEST(EventGroupSetTest, CollectSample) { - using ::testing::_; - using ::testing::Return; - using ::testing::SetArgPointee; - const CUpti_EventGroup g1{nullptr}; - const CUpti_EventGroup g2{reinterpret_cast(0x1000)}; - CUpti_EventGroup groups[] = {g1, g2}; - CUpti_EventGroupSet set; - set.eventGroups = groups; - set.numEventGroups = 2; - - std::map events; - Event instr("instructions"); - events[4] = std::move(instr); - Event cycles("cycles"); - events[5] = std::move(cycles); - Event branches("branches"); - events[10] = std::move(branches); - - MockCuptiEvents cupti_events; - EXPECT_CALL(cupti_events, enablePerInstance(g1)).Times(1); - EXPECT_CALL(cupti_events, enablePerInstance(g2)).Times(1); - EXPECT_CALL(cupti_events, instanceCount(g1)).Times(1).WillOnce(Return(80)); - EXPECT_CALL(cupti_events, instanceCount(g2)).Times(1).WillOnce(Return(40)); - std::vector events_in_group1 = {4, 5}; - EXPECT_CALL(cupti_events, eventsInGroup(g1)) - .Times(1) - .WillOnce(Return(events_in_group1)); - std::vector events_in_group2 = {10}; - EXPECT_CALL(cupti_events, eventsInGroup(g2)) - .Times(1) - .WillOnce(Return(events_in_group2)); - EventGroupSet group_set(set, events, cupti_events); - - EXPECT_EQ(group_set.groupCount(), 2); - EXPECT_EQ(events[4].instanceCount, 80); - EXPECT_EQ(events[5].instanceCount, 80); - EXPECT_EQ(events[10].instanceCount, 40); - - // This should not cause any Cupti API action as the group - // set is already disabled - group_set.setEnabled(false); - - // Activate group set - if activated twice, only the first - // should cause cupti API to be called - EXPECT_CALL(cupti_events, enableGroupSet(_)).Times(1); - group_set.setEnabled(false); - group_set.setEnabled(true); - - EXPECT_CALL(cupti_events, eventsInGroup(g1)) - .Times(1) - .WillOnce(Return(events_in_group1)); - EXPECT_CALL(cupti_events, eventsInGroup(g2)) - .Times(1) - .WillOnce(Return(events_in_group2)); - EXPECT_CALL(cupti_events, readEvent(g1, 4, _)).Times(1); - EXPECT_CALL(cupti_events, readEvent(g1, 5, _)).Times(1); - EXPECT_CALL(cupti_events, readEvent(g2, 10, _)).Times(1); - group_set.collectSample(); - - EXPECT_EQ(events[4].sampleCount(), 1); - EXPECT_EQ(events[5].sampleCount(), 1); - EXPECT_EQ(events[10].sampleCount(), 1); -} - -class MockLogger : public SampleListener { - public: - MOCK_METHOD3(handleSample, void(int device, const Sample& sample, bool from_new_version)); - MOCK_METHOD1(update, void(const Config& config)); -}; - -class EventProfilerTest : public ::testing::Test { - protected: - void SetUp() override { - auto cupti_events_ptr = std::make_unique(); - auto cupti_metrics_ptr = std::make_unique(); - cuptiEvents_ = cupti_events_ptr.get(); - cuptiMetrics_ = cupti_metrics_ptr.get(); - loggers_.push_back(std::make_unique()); - onDemandLoggers_.push_back(std::make_unique()); - profiler_ = std::make_unique( - std::move(cupti_events_ptr), - std::move(cupti_metrics_ptr), - loggers_, - onDemandLoggers_); - - for (int i = 0; i < kEventGroupCount; i++) { - eventGroups_[i] = &eventGroups_[i]; - } - for (int i = 0; i < kGroupSetCount; i++) { - // Default size to 1 but can be changed by test - groupSet_[i].numEventGroups = 1; - // Two groups per set - groupSet_[i].eventGroups = &eventGroups_[i * 2]; - } - groupSets_.numSets = 1; - groupSets_.sets = groupSet_; - } - - MockCuptiEvents* cuptiEvents_; - MockCuptiMetrics* cuptiMetrics_; - std::vector> loggers_; - std::vector> onDemandLoggers_; - constexpr static int kEventGroupCount = 4; - constexpr static int kGroupSetCount = 2; - CUpti_EventGroup eventGroups_[kEventGroupCount]; - CUpti_EventGroupSet groupSet_[kGroupSetCount]; - CUpti_EventGroupSets groupSets_; - std::unique_ptr profiler_; -}; - -TEST_F(EventProfilerTest, ConfigureFailure) { - using namespace testing; - - // Default config has no counters enabled. - // Check that profiler remains disabled. - Config cfg; - profiler_->configure(cfg, nullptr); - - EXPECT_FALSE(profiler_->enabled()); - - // There is no event named "cycles" - // In this case the profiler should print a warning and remain disabled - bool parsed = cfg.parse("EVENTS = cycles"); - EXPECT_TRUE(parsed); - - // EventProfiler should handle exception thrown from createGroupSets - // Configuration will be applied twice - once for combined base + on-demand - // and then again falling back to base - EXPECT_CALL(*cuptiEvents_, eventId("cycles")) - .Times(2) - .WillRepeatedly(Return(0)); - std::vector ids = {0}; - EXPECT_CALL(*cuptiEvents_, createGroupSets(ids)) - .Times(2) - .WillRepeatedly(Throw( - std::system_error(EINVAL, std::generic_category(), "Event ID"))); - profiler_->configure(cfg, nullptr); - - EXPECT_FALSE(profiler_->enabled()); -} - -TEST_F(EventProfilerTest, ConfigureBase) { - using namespace testing; - - // Test normal path, simple base config - Config cfg; - bool parsed = cfg.parse("EVENTS = elapsed_cycles_sm"); - EXPECT_TRUE(parsed); - - // One valid event - expect one call to eventId and createGroupSets - EXPECT_CALL(*cuptiEvents_, eventId("elapsed_cycles_sm")) - .Times(1) - .WillOnce(Return(5)); - std::vector ids = {5}; - EXPECT_CALL(*cuptiEvents_, createGroupSets(ids)) - .Times(1) - .WillOnce(Return(&groupSets_)); - EXPECT_CALL(*cuptiEvents_, enablePerInstance(eventGroups_[0])).Times(1); - EXPECT_CALL(*cuptiEvents_, instanceCount(eventGroups_[0])) - .Times(1) - .WillOnce(Return(80)); - EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[0])) - .Times(1) - .WillOnce(Return(ids)); - EXPECT_CALL(*cuptiEvents_, enableGroupSet(_)).Times(1); - - profiler_->configure(cfg, nullptr); - - EXPECT_TRUE(profiler_->enabled()); -} - -TEST_F(EventProfilerTest, ConfigureOnDemand) { - using namespace testing; - - // Test base + on-demand config, one event and one metric - Config cfg, on_demand_cfg; - bool parsed = cfg.parse(R"( - EVENTS = active_cycles - SAMPLE_PERIOD_MSECS=500 - REPORT_PERIOD_SECS=10 - SAMPLES_PER_REPORT=5 - )"); - EXPECT_TRUE(parsed); - - parsed = on_demand_cfg.parse(R"( - METRICS = ipc - EVENTS_DURATION_SECS=60 - SAMPLE_PERIOD_MSECS=200 - MULTIPLEX_PERIOD_MSECS=2000 - REPORT_PERIOD_SECS=3 - SAMPLES_PER_REPORT=10 - )"); - EXPECT_TRUE(parsed); - - // One event - EXPECT_CALL(*cuptiEvents_, eventId("active_cycles")) - .Times(1) - .WillOnce(Return(3)); - // One metric - EXPECT_CALL(*cuptiMetrics_, idFromName("ipc")).Times(1).WillOnce(Return(10)); - std::map ipc_events; - ipc_events[4] = "instructions"; - ipc_events[5] = "elapsed_cycles_sm"; - EXPECT_CALL(*cuptiMetrics_, events(10)).Times(1).WillOnce(Return(ipc_events)); - EXPECT_CALL(*cuptiMetrics_, evaluationMode(10)) - .Times(1) - .WillOnce(Return(CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE)); - EXPECT_CALL(*cuptiMetrics_, valueKind(10)) - .Times(1) - .WillOnce(Return(CUPTI_METRIC_VALUE_KIND_DOUBLE)); - std::vector ids = {3, 4, 5}; - groupSet_[0].numEventGroups = 2; - groupSets_.numSets = 2; - EXPECT_CALL(*cuptiEvents_, createGroupSets(ids)) - .Times(1) - .WillOnce(Return(&groupSets_)); - // Specified CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE per instance above - // So check that it's enabled - EXPECT_CALL(*cuptiEvents_, enablePerInstance(eventGroups_[0])).Times(1); - EXPECT_CALL(*cuptiEvents_, enablePerInstance(eventGroups_[1])).Times(1); - EXPECT_CALL(*cuptiEvents_, enablePerInstance(eventGroups_[2])).Times(1); - std::vector ids_g1{3}, ids_g2{4}, ids_g3{5}; - EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[0])) - .Times(1) - .WillOnce(Return(ids_g1)); - EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[1])) - .Times(1) - .WillOnce(Return(ids_g2)); - EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[2])) - .Times(1) - .WillOnce(Return(ids_g3)); - EXPECT_CALL(*cuptiEvents_, enableGroupSet(_)).Times(1); - - profiler_->configure(cfg, &on_demand_cfg); - - EXPECT_TRUE(profiler_->enabled()); - EXPECT_EQ(profiler_->samplePeriod().count(), 250); - EXPECT_EQ(profiler_->multiplexPeriod().count(), 1000); - EXPECT_EQ(profiler_->reportPeriod().count(), 10000); - EXPECT_EQ(profiler_->onDemandReportPeriod().count(), 4000); -} - -TEST_F(EventProfilerTest, ReportSample) { - using namespace testing; - - // Test base + on-demand config, one event and one metric - Config cfg, on_demand_cfg; - bool parsed = cfg.parse("EVENTS = active_cycles"); - EXPECT_TRUE(parsed); - - parsed = on_demand_cfg.parse(R"( - METRICS = ipc - EVENTS_DURATION_SECS=60 - )"); - EXPECT_TRUE(parsed); - - // One event - EXPECT_CALL(*cuptiEvents_, eventId("active_cycles")) - .Times(1) - .WillOnce(Return(3)); - // One metric - EXPECT_CALL(*cuptiMetrics_, idFromName("ipc")).Times(1).WillOnce(Return(10)); - std::map ipc_events; - ipc_events[4] = "instructions"; - ipc_events[5] = "elapsed_cycles_sm"; - EXPECT_CALL(*cuptiMetrics_, events(10)).Times(1).WillOnce(Return(ipc_events)); - EXPECT_CALL(*cuptiMetrics_, evaluationMode(10)) - .Times(1) - .WillOnce(Return(CUPTI_METRIC_EVALUATION_MODE_PER_INSTANCE)); - EXPECT_CALL(*cuptiMetrics_, valueKind(10)) - .Times(1) - .WillOnce(Return(CUPTI_METRIC_VALUE_KIND_DOUBLE)); - std::vector ids = {3, 4, 5}; - groupSet_[0].numEventGroups = 2; - groupSets_.numSets = 2; - EXPECT_CALL(*cuptiEvents_, createGroupSets(ids)) - .Times(1) - .WillOnce(Return(&groupSets_)); - EXPECT_CALL(*cuptiEvents_, instanceCount(_)) - .Times(3) - .WillRepeatedly(Return(4)); - std::vector ids_g1{3}, ids_g2{4}, ids_g3{5}; - // These will be called by collectSample() as well, which is called twice - // per group set - EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[0])) - .Times(3) - .WillRepeatedly(Return(ids_g1)); - EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[1])) - .Times(3) - .WillRepeatedly(Return(ids_g2)); - EXPECT_CALL(*cuptiEvents_, eventsInGroup(eventGroups_[2])) - .Times(3) - .WillRepeatedly(Return(ids_g3)); - EXPECT_CALL(*cuptiEvents_, enableGroupSet(_)).Times(1); - - profiler_->configure(cfg, &on_demand_cfg); - - EXPECT_TRUE(profiler_->enabled()); - - EXPECT_CALL(*cuptiEvents_, readEvent(_, _, _)) - .Times(6) - .WillRepeatedly(Invoke( - [](CUpti_EventGroup g, CUpti_EventID id, std::vector& vals) { - vals = {1, 2, 3, 4}; - })); - - // Need to collect four times - twice for each group set - profiler_->collectSample(); - profiler_->collectSample(); - EXPECT_CALL(*cuptiEvents_, disableGroupSet(_)).Times(1); - EXPECT_CALL(*cuptiEvents_, enableGroupSet(_)).Times(1); - profiler_->enableNextCounterSet(); - profiler_->collectSample(); - profiler_->collectSample(); - - std::vector ipc_ids = {4, 5}; - // Called once for each instance (4) and once for the total. - // x2 since we recompute per logger. - EXPECT_CALL( - *cuptiMetrics_, - calculate(10, CUPTI_METRIC_VALUE_KIND_DOUBLE, ipc_ids, _, 2000000000)) - .Times(10) - .WillRepeatedly(Return(SampleValue(0.3))); - auto& logger = dynamic_cast(*loggers_[0]); - EXPECT_CALL(logger, handleSample(0, _, _)) - .Times(1) - .WillOnce(Invoke([](int device, const Sample& sample, bool from_new_version) { - // Sample will include all stats - logger must pick the - // ones it wants. - EXPECT_EQ(sample.stats.size(), 4); - EXPECT_EQ(sample.stats[0].name, "active_cycles"); - EXPECT_EQ(sample.stats[1].name, "instructions"); - EXPECT_EQ(sample.stats[2].name, "elapsed_cycles_sm"); - EXPECT_EQ(sample.stats[3].name, "ipc"); - // 2 samples, each with values {1, 2, 3, 4} - // i.e. {2, 4, 6, 8} total - EXPECT_EQ(sample.stats[0].total.getInt(), 20); - EXPECT_EQ(sample.stats[0].percentileValues[0].second.getInt(), 2); - EXPECT_EQ(sample.stats[0].percentileValues.back().second.getInt(), 8); - // ipc is always 0.3 from mocked calculate function above - EXPECT_EQ(sample.stats[3].total.getDouble(), 0.3); - EXPECT_EQ(sample.stats[3].percentileValues[0].second.getDouble(), 0.3); - EXPECT_EQ( - sample.stats[3].percentileValues.back().second.getDouble(), 0.3); - })); - profiler_->reportSamples(); - - auto& on_demand_logger = dynamic_cast(*onDemandLoggers_[0]); - EXPECT_CALL(on_demand_logger, handleSample(0, _, _)).Times(1); - profiler_->reportOnDemandSamples(); - - EXPECT_CALL(*cuptiEvents_, disableGroupSet(_)).Times(1); -} diff --git a/plugins/tensorboard-plugins/libkineto/test/LoggerObserverTest.cpp b/plugins/tensorboard-plugins/libkineto/test/LoggerObserverTest.cpp deleted file mode 100644 index 30ba4a824af10401a45100b0b39cec54fcf98680..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/LoggerObserverTest.cpp +++ /dev/null @@ -1,96 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include -#include - -// TODO(T90238193) -// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include "include/libkineto.h" -#include "src/Logger.h" -#include "LoggerCollector.h" - -using namespace KINETO_NAMESPACE; - -#if !USE_GOOGLE_LOG - -constexpr char InfoTestStr[] = "Checking LOG(INFO)"; -constexpr char WarningTestStr[] = "Checking LOG(WARNING)"; -constexpr char ErrorTestStr[] = "Checking LOG(ERROR)"; - -TEST(LoggerObserverTest, SingleCollectorObserver) { - // Add a LoggerObserverCollector to collect all logs during the trace. - std::unique_ptr lCollector = std::make_unique(); - Logger::addLoggerObserver(lCollector.get()); - - LOG(INFO) << InfoTestStr; - LOG(WARNING) << WarningTestStr; - LOG(ERROR) << ErrorTestStr; - - auto LoggerMD = lCollector->extractCollectorMetadata(); - EXPECT_TRUE(LoggerMD[LoggerOutputType::INFO][0].find(InfoTestStr) != std::string::npos); - EXPECT_TRUE(LoggerMD[LoggerOutputType::WARNING][0].find(WarningTestStr) != std::string::npos); - EXPECT_TRUE(LoggerMD[LoggerOutputType::ERROR][0].find(ErrorTestStr) != std::string::npos); - - Logger::removeLoggerObserver(lCollector.get()); -} - -#define NUM_OF_MESSAGES_FOR_EACH_TYPE 10 -#define NUM_OF_WRITE_THREADS 200 - -// Writes NUM_OF_MESSAGES_FOR_EACH_TYPE messages for each INFO, WARNING, and ERROR. -// NOLINTNEXTLINE(clang-diagnostic-unused-parameter) -void* writeSeveralMessages(void* ptr) { - for(int i=0; i lc1 = std::make_unique(); - std::unique_ptr lc2 = std::make_unique(); - std::unique_ptr lc3 = std::make_unique(); - std::unique_ptr lc4 = std::make_unique(); - Logger::addLoggerObserver(lc1.get()); - Logger::addLoggerObserver(lc2.get()); - Logger::addLoggerObserver(lc3.get()); - Logger::addLoggerObserver(lc4.get()); - - // Launch NUM_OF_WRITE_THREADS threads writing several messages. - pthread_t ListOfThreads[NUM_OF_WRITE_THREADS]; - for (int i=0; iextractCollectorMetadata(); - int InfoCount = 0, WarnCount = 0, ErrorCount = 0; - for (auto& md : lc1MD) { - InfoCount += md.first == LoggerOutputType::INFO ? md.second.size() : 0; - WarnCount += md.first == LoggerOutputType::WARNING ? md.second.size() : 0; - ErrorCount += md.first == LoggerOutputType::ERROR ? md.second.size() : 0; - } - - EXPECT_EQ(InfoCount, NUM_OF_WRITE_THREADS * NUM_OF_MESSAGES_FOR_EACH_TYPE); - EXPECT_EQ(WarnCount, NUM_OF_WRITE_THREADS * NUM_OF_MESSAGES_FOR_EACH_TYPE); - EXPECT_EQ(ErrorCount, NUM_OF_WRITE_THREADS * NUM_OF_MESSAGES_FOR_EACH_TYPE); - - Logger::removeLoggerObserver(lc1.get()); - Logger::removeLoggerObserver(lc2.get()); - Logger::removeLoggerObserver(lc3.get()); - Logger::removeLoggerObserver(lc4.get()); -} - -#endif // !USE_GOOGLE_LOG - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.cpp b/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.cpp deleted file mode 100644 index 89f1d536ca8d6d794b7ffc7402001d0e3d4d9c06..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include -#include -#include - -#include "test/MockActivitySubProfiler.h" - -namespace libkineto { - -const std::set supported_activities {ActivityType::CPU_OP}; -const std::string profile_name{"MockProfiler"}; - -void MockProfilerSession::processTrace(ActivityLogger& logger) { - for (const auto& activity: activities()) { - activity.log(logger); - } -} - -const std::string& MockActivityProfiler::name() const { - return profile_name; -} - -const std::set& MockActivityProfiler::availableActivities() const { - return supported_activities; -} - -MockActivityProfiler::MockActivityProfiler( - std::vector& activities) : - test_activities_(activities) {}; - -std::unique_ptr MockActivityProfiler::configure( - const std::set& /*activity_types*/, - const Config& /*config*/) { - auto session = std::make_unique(); - session->set_test_activities(std::move(test_activities_)); - return session; -}; - -std::unique_ptr MockActivityProfiler::configure( - int64_t /*ts_ms*/, - int64_t /*duration_ms*/, - const std::set& activity_types, - const Config& config) { - return configure(activity_types, config); -}; - -} // namespace libkineto - diff --git a/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.h b/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.h deleted file mode 100644 index 36eaa13d1a544c624a2f4bb053891d055686ebf4..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.h +++ /dev/null @@ -1,72 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#pragma once - -#include -#include -#include - -#include "include/IActivityProfiler.h" - -namespace libkineto { - -class MockProfilerSession: public IActivityProfilerSession { - - public: - explicit MockProfilerSession() {} - - void start() override { - start_count++; - status_ = TraceStatus::RECORDING; - } - - void stop() override { - stop_count++; - status_ = TraceStatus::PROCESSING; - } - - std::vector& activities() override { - return test_activities_; - } - - std::vector errors() override { - return {}; - } - - void processTrace(ActivityLogger& logger) override; - - void set_test_activities(std::vector&& acs) { - test_activities_ = std::move(acs); - } - - int start_count = 0; - int stop_count = 0; - private: - std::vector test_activities_; -}; - - -class MockActivityProfiler: public IActivityProfiler { - - public: - explicit MockActivityProfiler(std::vector& activities); - - const std::string& name() const override; - - const std::set& availableActivities() const override; - - std::unique_ptr configure( - const std::set& activity_types, - const Config& config) override; - - std::unique_ptr configure( - int64_t ts_ms, - int64_t duration_ms, - const std::set& activity_types, - const Config& config) override; - - private: - std::vector test_activities_; -}; - -} // namespace libkineto diff --git a/plugins/tensorboard-plugins/libkineto/test/PidInfoTest.cpp b/plugins/tensorboard-plugins/libkineto/test/PidInfoTest.cpp deleted file mode 100644 index b86cfb36d0581ba9a8a03a09724b181c2fd2e88a..0000000000000000000000000000000000000000 --- a/plugins/tensorboard-plugins/libkineto/test/PidInfoTest.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -#include "include/ThreadUtil.h" - -#include -#include - -#include -#include - -using namespace KINETO_NAMESPACE; - -TEST(ThreadNameTest, setAndGet) { - setThreadName("ThreadNameTest"); - EXPECT_EQ(getThreadName(), "ThreadNameTest"); - - setThreadName(""); - EXPECT_EQ(getThreadName(), ""); - - // Spaces etc are ok - setThreadName("Name w/ spaces"); - EXPECT_EQ(getThreadName(), "Name w/ spaces"); - - // More than 16 chars is not OK - setThreadName("More than 16 characters"); - EXPECT_EQ(getThreadName(), "Name w/ spaces"); -} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/.gitignore b/plugins/tensorboard-plugins/tb_graph_ascend/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..702406e615348632bfbb57fb513fa7ffbc67af1c --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/.gitignore @@ -0,0 +1,8 @@ +node_modules/ +package-lock.json +.npmrc +yarn.lock +dist/ +build/ +tb_graph_ascend.egg-info/ +__pycache__/ \ No newline at end of file diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/README.md b/plugins/tensorboard-plugins/tb_graph_ascend/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3f1cdbdd1fb93abd2907a717ff3c3e873faf2f5a --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/README.md @@ -0,0 +1,200 @@ +# tb-graph-ascend + +### 介绍 + +此工具是将模型结构进行分级可视化展示的 Tensorboard 插件。可将模型的层级关系、精度性能数据进行可视化,并支持将采集数据和标杆模型进行关联比对,方便用户快速定位精度问题。 + +### 快速安装说明 + +- 相关依赖: + + `python >= 3.7 ,tensorboard >= 2.11.2,numpy <= 1.26.3` + +- 安装方式 + + 1. pip 安装(推荐) + + - 现本插件已经上传到 pypi 社区,用户可在 python 环境下直接通过以下 pip 指令进行安装: + ``` + pip install tb-graph-ascend + ``` + - 也可在 pypi 社区上下载离线 whl 包,传输到无法访问公网的环境上离线安装使用。访问[下载链接](https://pypi.org/project/tb-graph-ascend/#files)选择 whl 包进行下载,之后便可使用指令安装(此处{version}为 whl 包实际版本) + ``` + pip install tb-graph_ascend_{version}-py3-none-any.whl + ``` + + 2. 从源代码安装 + + - 从仓库下载源码并切换到 poc 分支: + + ``` + git clone https://gitee.com/ascend/mstt.git -b poc + ``` + + - 进入目录 `plugins/tensorboard-plugins/tb_graph_ascend` 下 + - 编译前端代码,根据操作系统选取不同指令 + + ``` + cd fe + // Windows系统 + npm run buildWin + // 其他可使用cp指令的系统,如Linux或Mac + npm run buildLinux + ``` + + **注意**: 此步骤需要安装 Node.js 环境 + + - 回到上级目录直接安装: + ``` + cd ../ + python setup.py develop + ``` + - 或: 构建 whl 包安装 + ``` + python setup.py bdist_wheel + ``` + 在 `plugins/tensorboard-plugins/tb_graph_ascend/dist` 目录下取出 whl 包,使用以下指令安装(此处{version}为 whl 包实际版本) + ``` + pip install tb-graph_ascend_{version}-py3-none-any.whl + ``` + +### 解析数据说明 + +- 准备数据 + + 将通过[msprobe](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe#10-%E5%88%86%E7%BA%A7%E5%8F%AF%E8%A7%86%E5%8C%96%E6%9E%84%E5%9B%BE%E6%AF%94%E5%AF%B9)工具采集得到的文件后缀为.vis 的模型结构文件(文件本身为 json 格式)放置于某个文件夹中,路径名称下文称之为 `out_path` \ + E.g. \ + `---output_path` \ + `-----output.vis` \ + `-----output2.vis` + +### 启动方式 + +1. 启动 TensorBoard + + ``` + tensorboard --logdir output_path + ``` + + 如果网络浏览器与启动 TensorBoard 的机器不在同一台机器上,则需要在尾部加上`--bind_all`命令,如: + + ``` + tensorboard --logdir output_path --bind_all + ``` + + 注意:确保默认端口 6006 对浏览器的主机打开。 + + 如果需要切换端口号需要在尾部加上指定的端口号,如`--port=6007` + + ``` + tensorboard --logdir output_path --port=6007 + ``` + +2. 在浏览器上打开 tensorboard + + 在浏览器中打开 URL: `http://localhost:6006`。 + 如果 tensorboard 启动命令使用`--bind_all`, 则需将主机名由`localhost`替换为主机的 ip 地址。 + + 注意:如果`--logdir` 指定目录下的文件太大或太多,请等候,刷新浏览器查看加载结果。 + +### 前端主体 + +#### 页面展示说明 + +① 设置栏 +② 图例 +③ 图结构 +④ 节点信息 +⑤ 缩略图 +![输入图片说明](./doc/images/mainPage.png) + +#### 操作方式: + +1. 节点双击打开,单击选中。 +2. 选中的节点边框呈现蓝色,比对场景下若其存在对应节点,则对应节点边框为浅蓝色 +3. 键盘 WS 根据鼠标位置放大缩小,AD 左右移动, +4. 鼠标滚轮上下移动,鼠标可拖动页面 +5. 比对场景鼠标右键可选中节点,并可展开至对应侧的节点并选中 + +#### 设置栏 + +##### 设置 + +![输入图片说明](./doc/images/setting.png) +| 字段 | 含义 | +|------|------| +| Fit to screen | 根据浏览器窗口的大小和比例,自动缩放图表,完整显示整个模型结构 | +| Download PNG | 以PNG的形式下载当前展开的模型结构 | +| Run | 模型结构文件所在目录 | +| Tag | 模型结构文件名称 | +| MicroStep | 展示模型结构文件某个MicroStep(若未按MicroStep采集则不会出现) | +| Step | 展示模型结构文件某个Step(若未按Step没有则不会出现) | + +##### 目录 + +1. 目录中的节点名称、层级信息、颜色、是否可以展开与图中的对应节点一致。 + +2. 若存在箭头则代表节点可展开,若不存在则代表其不存在子节点。 + +3. 单击目录中的节点时,若可以展开,则图中相应的节点会展开并被选中,否则仅选中。 + +4. 单击目录中已展开的节点时,目录和图中的结构都会收缩,若图中结构已经收缩,则仅选中。 + +5. 如果图中该节点已经展开,则仅会选中该节点,而不会收缩。 + +##### 搜索 + +| 字段 | 含义 | +| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------- | +| 节点搜索 | 搜索节点名称,右边有 NPU 侧和标杆侧(Bench)可以选择 | +| 精度筛选 | 勾选想要查看的颜色,可得到筛选后的节点列表,鼠标点击列表中节点名称,图自动展开到指定节点(仅展示叶子节点,展示顺序按 dump 时间序) | +| 标杆侧未匹配节点(比对场景特有) | 展示标杆侧未匹配节点,可通过列表或搜索上/下一个按钮来切换展示的节点 | + +##### 匹配(比对场景特有) + +| 字段 | 含义 | +| ---------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| 未匹配节点 | 会展示 NPU 侧和标杆侧的未匹配节点,点击列表节点会自动展开至对应层并选中该节点 | +| 匹配节点 | 根据数据模式(md5 模式/统计值模式/tensor 模式)对未匹配两侧的各一个节点进行匹配,若不能匹配则给出原因,若能匹配成功则该对节点移入已匹配节点中,匹配结果将在原始文件中持久化记录 | +| 已匹配节点 | 展示已匹配成功节点,点击列表节点会自动展开至对应层并选中该节点,并且会在另个列表中自动选中对应的节点 | +| 匹配列表 | 会显示已经匹配的两个节点的列表 | +| 取消匹配 | 将两个节点的所有属性还原至未匹配之前的状态,匹配列表中也会减去这两个节点 | + +#### 图例 + +##### 节点图例 + +| 图例名称 | 说明 | +| ------------------------------ | --------------------------------------------- | +| Module or Operators | 模块或者计算单元 | +| Unexpanded Module or Operators | 无法展开的模块或者计算单元 | +| Api List | 模块间游离的 API 的集合,只会在首层节点中出现 | + +##### 颜色图例: + +展示颜色及对应的精度区间,精度误差越大,越可疑,颜色越深 + +#### 图结构 + +图中结构为模型层级结构,为顺序执行序,顺序为从上到下,从左到右 +* 若为对比图,节点的颜色代表这个节点的精度误差 +* 节点单击选中,展示节点信息,若可展开双击展开,右键若为对比图可进行跳转功能 + +#### 节点信息 + +![输入图片说明](./doc/images/nodeInfo.png) \ +从上到下依次为: +| 信息名称 | 说明 | +|-------------|-----------------| +| 节点名称 | dump 下来的数据中节点的全名 | +| Subgraph | 子节点个数 | +| Suggestions | 对此节点的建议 | +| Attributes | 除输入输出外其他要展示的数据 | +| Inputs | 输入参数 | +| Outputs | 输出参数 | +| StackInfo | 堆栈信息 | + +#### 缩略图 + +展示整个模型计算图的缩略图和当前视窗所在位置,可以通过在缩率图中点击和拖动视窗快速查看模型的不同部分。 + diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/mainPage.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/mainPage.png new file mode 100644 index 0000000000000000000000000000000000000000..6084cd58db838864687999da2ad93fcccf0625de Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/mainPage.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/nodeInfo.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/nodeInfo.png new file mode 100644 index 0000000000000000000000000000000000000000..8653ecec1fc62eb39c425cbe7c9294f09c9db290 Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/nodeInfo.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/setting.png b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/setting.png new file mode 100644 index 0000000000000000000000000000000000000000..54cc998988b7dca765c3d9a772ea7832b4be65ea Binary files /dev/null and b/plugins/tensorboard-plugins/tb_graph_ascend/doc/images/setting.png differ diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/.prettierrc b/plugins/tensorboard-plugins/tb_graph_ascend/fe/.prettierrc new file mode 100644 index 0000000000000000000000000000000000000000..e3d2acb00457084b2f6cccafb8c95740e0344485 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/.prettierrc @@ -0,0 +1,13 @@ +{ + "parser": "typescript", + "semi": true, + "singleQuote": true, + "jsxSingleQuote": false, + "bracketSpacing": true, + "tabWidth": 2, + "useTabs": false, + "trailingComma": "all", + "proseWrap": "always", + "endOfLine": "lf", + "printWidth": 120 +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/index.html b/plugins/tensorboard-plugins/tb_graph_ascend/fe/index.html new file mode 100644 index 0000000000000000000000000000000000000000..06b36c3784738d1ec2632847c450896db04a4bbb --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/index.html @@ -0,0 +1,28 @@ + + + + + + + + + + + diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/package.json b/plugins/tensorboard-plugins/tb_graph_ascend/fe/package.json new file mode 100644 index 0000000000000000000000000000000000000000..a5ceedfba48f3b6dd55f7a41cb7825fc62205a65 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/package.json @@ -0,0 +1,74 @@ +{ + "name": "tb-graph-ascend", + "version": "0.1.0", + "description": "#### Description {**When you're done, you can delete the content in this README and update the file with details for others getting started with your repository**}", + "main": "index.js", + "scripts": { + "buildLinux": "cross-env NODE_ENV=production webpack && cp dist/index.html ../server/static/", + "buildWin": "cross-env NODE_ENV=production webpack && copy dist\\index.html ..\\server\\static\\", + "dev": "webpack server", + "prettier": "prettier --config ./.prettierrc --write ./src/**/*.ts" + }, + "devDependencies": { + "@types/d3": "5.7.2", + "@types/lodash": "^4.14.172", + "@types/node": "^16.4.13", + "@types/offscreencanvas": "^2019.6.3", + "@types/requirejs": "^2.1.33", + "@types/resize-observer-browser": "^0.1.6", + "@types/three": "^0.131.0", + "html-loader": "^5.1.0", + "html-webpack-plugin": "^5.6.3", + "inline-chunk-html-plugin": "^1.1.1", + "ts-loader": "^9.5.1", + "tslib": "^2.6.2", + "typescript": "^5.4.5", + "webpack": "^5.96.1", + "webpack-cli": "^5.1.4" + }, + "dependencies": { + "@polymer/decorators": "^3.0.0", + "@polymer/iron-behaviors": "^3.0.1", + "@polymer/iron-collapse": "^3.0.1", + "@polymer/iron-flex-layout": "^3.0.1", + "@polymer/iron-icon": "^3.0.1", + "@polymer/iron-icons": "^3.0.1", + "@polymer/iron-iconset-svg": "^3.0.1", + "@polymer/iron-list": "^3.1.0", + "@polymer/iron-pages": "^3.0.1", + "@polymer/iron-resizable-behavior": "^3.0.1", + "@polymer/paper-behaviors": "^3.0.1", + "@polymer/paper-button": "^3.0.1", + "@polymer/paper-checkbox": "^3.1.0", + "@polymer/paper-dialog": "^3.0.1", + "@polymer/paper-dialog-scrollable": "^3.0.1", + "@polymer/paper-dropdown-menu": "^3.1.0", + "@polymer/paper-header-panel": "^3.0.1", + "@polymer/paper-icon-button": "^3.0.2", + "@polymer/paper-input": "^3.2.1", + "@polymer/paper-item": "^3.0.1", + "@polymer/paper-listbox": "^3.0.1", + "@polymer/paper-material": "^3.0.1", + "@polymer/paper-menu-button": "^3.0.1", + "@polymer/paper-progress": "^3.0.1", + "@polymer/paper-radio-button": "^3.0.1", + "@polymer/paper-radio-group": "^3.0.1", + "@polymer/paper-slider": "^3.0.1", + "@polymer/paper-spinner": "^3.0.2", + "@polymer/paper-styles": "^3.0.1", + "@polymer/paper-tabs": "^3.1.0", + "@polymer/paper-toast": "^3.0.1", + "@polymer/paper-toggle-button": "^3.0.1", + "@polymer/paper-toolbar": "^3.0.1", + "@polymer/paper-tooltip": "^3.0.1", + "@polymer/polymer": "^3.5.1", + "@types/lodash": "^4.17.1", + "clean-webpack-plugin": "^4.0.0", + "cross-env": "^7.0.3", + "css-loader": "^7.1.2", + "d3": "5.7.0", + "dagre": "^0.8.5", + "lodash": "^4.17.21", + "style-loader": "^4.0.0" + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/index.css b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/index.css new file mode 100644 index 0000000000000000000000000000000000000000..a28ea1aebe630fe879bb4672047bc8bae03ec274 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/index.css @@ -0,0 +1,7 @@ +html, +body, +graph-app { + height: 100%; + margin: 0; + font-family: Roboto, sans-serif; +} \ No newline at end of file diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..6e65aac4e05396ea7f9b2a749d16fc810d60a64e --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/index.ts @@ -0,0 +1,18 @@ +/* Copyright (c) 2024, Huawei Technologies. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import './tf_graph_dashboard/index'; +import './index.css'; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/dark_mode_mixin.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/dark_mode_mixin.ts new file mode 100644 index 0000000000000000000000000000000000000000..8a6ce19b69b08f348b56cd1a80a4ea602f69cb6a --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/dark_mode_mixin.ts @@ -0,0 +1,62 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { PolymerElement } from '@polymer/polymer'; + +/** + * Polymer mixin replacement for `:host-context(body.dark-mode)`. + *- + * Unfortunately, Firefox does not support `:host-context()` and cannot use the + * WebComponent way of styling shadow DOMs with context for ancestor [1][2]. + * To work around the issue, we are creating a WebComponent mixin that adds + * class `dark-mode` to `:host` when body contains the class, `.dark-mode`. + * + * Unfortunately, due to our infamiliarity with mixins, our types are imperfect. + * + * [1]: https://developer.mozilla.org/en-US/docs/Web/CSS/:host-context() + * [2]: https://bugzilla.mozilla.org/show_bug.cgi?id=1082060 + */ +export function DarkModeMixin( + Base: new () => PolymerElement +): new () => T { + return class Foo extends Base { + private observer?: MutationObserver; + + override connectedCallback() { + super.connectedCallback(); + this._maybeSetDarkMode(); + + this.observer = new MutationObserver((mutations) => { + const classChanged = mutations.some((mutation) => { + return mutation.attributeName === 'class'; + }); + if (classChanged) this._maybeSetDarkMode(); + }); + this.observer.observe(document.body, { attributes: true }); + } + + override disconnectedCallback() { + super.disconnectedCallback(); + this.observer?.disconnect(); + } + + private _maybeSetDarkMode() { + this.classList.toggle( + 'dark-mode', + document.body.classList.contains('dark-mode') + ); + } + } as unknown as new () => T; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/dom-repeat.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/dom-repeat.ts new file mode 100644 index 0000000000000000000000000000000000000000..53f21caa684fba44ef4bc35c9094340fd06cf18e --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/dom-repeat.ts @@ -0,0 +1,16 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +export * from '@polymer/polymer/lib/elements/dom-repeat'; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/dom.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/dom.ts new file mode 100644 index 0000000000000000000000000000000000000000..7b1ace8b926fd768f63b9aa03c4163c05bacdb8a --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/dom.ts @@ -0,0 +1,16 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +export * from '@polymer/polymer/lib/legacy/polymer.dom'; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/irons_and_papers.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/irons_and_papers.ts new file mode 100644 index 0000000000000000000000000000000000000000..23b3fc88ad2fe9e570969044c79c96b77a55b1d0 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/irons_and_papers.ts @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/** + * @fileoverview Imports all TensorBoard dependencies to paper and iron components. Please + * import this module for dependency on iron and paper components. + */ + +import '@polymer/iron-collapse/iron-collapse'; +import '@polymer/iron-flex-layout/iron-flex-layout-classes'; +import '@polymer/iron-icon'; +import '@polymer/iron-icons/image-icons'; +import '@polymer/iron-icons/iron-icons'; +import '@polymer/iron-iconset-svg'; +import '@polymer/iron-list/iron-list'; +import '@polymer/iron-pages'; +import '@polymer/paper-button'; +import '@polymer/paper-checkbox'; +import '@polymer/paper-dialog'; +import '@polymer/paper-dialog-scrollable'; +import '@polymer/paper-dropdown-menu/paper-dropdown-menu'; +import '@polymer/paper-header-panel'; +import '@polymer/paper-icon-button/paper-icon-button'; +import '@polymer/paper-input/paper-input'; +import '@polymer/paper-input/paper-textarea'; +import '@polymer/paper-item'; +import '@polymer/paper-listbox'; +import '@polymer/paper-material'; +import '@polymer/paper-menu-button'; +import '@polymer/paper-progress'; +import '@polymer/paper-radio-button'; +import '@polymer/paper-radio-group'; +import '@polymer/paper-slider'; +import '@polymer/paper-spinner/paper-spinner'; +import '@polymer/paper-spinner/paper-spinner-lite'; +import '@polymer/paper-styles/paper-styles'; +import '@polymer/paper-tabs'; +import '@polymer/paper-toast'; +import '@polymer/paper-toggle-button'; +import '@polymer/paper-toolbar'; +import '@polymer/paper-tooltip/paper-tooltip'; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/legacy_element_mixin.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/legacy_element_mixin.ts new file mode 100644 index 0000000000000000000000000000000000000000..e0768eac40066592e566d145cb6f9d1ed9054f31 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/legacy_element_mixin.ts @@ -0,0 +1,16 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +export * from '@polymer/polymer/lib/legacy/legacy-element-mixin'; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/register_style_dom_module.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/register_style_dom_module.ts new file mode 100644 index 0000000000000000000000000000000000000000..6f3ee280ec8502bf58ee2977d1d3f253efd39972 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/polymer/register_style_dom_module.ts @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import '@polymer/polymer/lib/elements/dom-module'; + +export interface DomModuleOptions { + moduleName: string; + styleDependencies?: string[]; + styleContent: string; +} + +/** + * Interop for Polymer 3 styling + * + * From https://polymer-library.polymer-project.org/3.0/docs/devguide/style-shadow-dom: + * The following process is a workaround. While Polymer 3.0 does not use + * elements for templating, style modules do. The following process + * is a workaround for this fact. This process may be updated as required. + */ +export function registerStyleDomModule(args: DomModuleOptions): void { + const {moduleName, styleContent} = args; + const domModule = document.createElement('dom-module'); + const template = document.createElement('template'); + + const styleIncludes: HTMLStyleElement[] = []; + if (args.styleDependencies) { + args.styleDependencies.forEach((dep) => { + const style = document.createElement('style'); + style.setAttribute('include', dep); + styleIncludes.push(style); + }); + } + const style = document.createElement('style'); + Object.assign(style, {textContent: styleContent}); + + styleIncludes.forEach((styleElement) => { + template.content.appendChild(styleElement); + }); + template.content.appendChild(style); + domModule.appendChild(template); + (domModule as any).register(moduleName); +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tb_debug/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tb_debug/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..8127393ed4d8e012f118b7657a3ae5b7e8448f40 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tb_debug/index.ts @@ -0,0 +1,26 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {ActionEvent} from './types'; + +export * from './types'; + +/** + * A method used only by Polymer code for notifying TensorBoard developers + * locally about events for debugging purposes. Do not use these events as + * lifecycle hooks in production. + * + * It is intentionally a no-op. There is no usage tracking in TensorBoard. + */ +export function notifyActionEventFromPolymer(actionEvent: ActionEvent) {} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tb_debug/types.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tb_debug/types.ts new file mode 100644 index 0000000000000000000000000000000000000000..f3b3687d91c97357393d30354ec66a245255cd0a --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tb_debug/types.ts @@ -0,0 +1,78 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +export interface ActionEvent { + eventCategory: string; + eventAction: string; + eventLabel?: string; + eventValue?: number; +} + +export const GRAPH_DEBUG_ACTION_EVENT_CATEGORY = 'Graph dashboard actions'; +export const GRAPH_DEBUG_TIMING_EVENT_CATEGORY = 'Graph dashboard timings'; + +/** + * Timing-based events, part of `GRAPH_DEBUG_TIMING_EVENT_CATEGORY`. + */ +export enum GraphDebugTimingEventId { + // Pre-rendering. + // `FETCH_PBTXT_BYTES` is fired for both filesystem and server sources. + FETCH_PBTXT_BYTES = 'FETCH_PBTXT_BYTES', + FETCH_PBTXT_BYTES_FROM_FILESYSTEM = 'FETCH_PBTXT_BYTES_FROM_FILESYSTEM', + FETCH_PBTXT_BYTES_FROM_SERVER = 'FETCH_PBTXT_BYTES_FROM_SERVER', + PARSE_PBTXT_INTO_OBJECT = 'PARSE_PBTXT_INTO_OBJECT', + FETCH_METADATA_PBTXT_BYTES = 'FETCH_METADATA_PBTXT_BYTES', + PARSE_METADATA_PBTXT_INTO_OBJECT = 'PARSE_METADATA_PBTXT_INTO_OBJECT', + NORMALIZING_NAMES = 'NORMALIZING_NAMES', + BUILD_SLIM_GRAPH = 'BUILD_SLIM_GRAPH', + HIERARCHY_ADD_NODES = 'HIERARCHY_ADD_NODES', + HIERARCHY_DETECT_SERIES = 'HIERARCHY_DETECT_SERIES', + HIERARCHY_ADD_EDGES = 'HIERARCHY_ADD_EDGES', + HIERARCHY_FIND_SIMILAR_SUBGRAPHS = 'HIERARCHY_FIND_SIMILAR_SUBGRAPHS', + // Rendering. + RENDER_BUILD_HIERARCHY = 'RENDER_BUILD_HIERARCHY', + RENDER_SCENE_LAYOUT = 'RENDER_SCENE_LAYOUT', + RENDER_SCENE_BUILD_SCENE = 'RENDER_SCENE_BUILD_SCENE', + // Total graph loading (superset of other phases). Note that after [1], + // this timing no longer includes `HIERARCHY_FIND_SIMILAR_SUBGRAPHS`, + // which is computed lazily. + // [1] https://github.com/tensorflow/tensorboard/pull/4742 + GRAPH_LOAD_SUCCEEDED = 'GRAPH_LOAD_SUCCEEDED', + GRAPH_LOAD_FAILED = 'GRAPH_LOAD_FAILED', +} + +/** + * Non-timing based actions due to user interaction, part of + * `GRAPH_DEBUG_ACTION_EVENT_CATEGORY`. + */ +export enum GraphDebugActionEventId { + // Labeled by state: expanded or collapsed. + NODE_EXPANSION_TOGGLED = 'NODE_EXPANSION_TOGGLED', + NODE_SEARCH_RESULT_FOCUSED = 'NODE_SEARCH_RESULT_FOCUSED', + // Labeled by direction between auxiliary graph and the main graph. + NODE_AUXILIARY_EXTRACTION_CHANGED = 'NODE_AUXILIARY_EXTRACTION_CHANGED', + // Labeled by graph type: Op, Conceptual, Profile. + GRAPH_TYPE_CHANGED = 'GRAPH_TYPE_CHANGED', + TRACE_INPUT_MODE_TOGGLED = 'TRACE_INPUT_MODE_TOGGLED', + // Labeled by mode: Structure, Device, TPU Compat, etc. + NODE_COLOR_MODE_CHANGED = 'NODE_COLOR_MODE_CHANGED', + UPLOADED_GRAPH_FROM_FILESYSTEM = 'UPLOADED_GRAPH_FROM_FILESYSTEM', +} + +// Merge the string enums. +export const GraphDebugEventId = { + ...GraphDebugTimingEventId, + ...GraphDebugActionEventId, +}; +export type GraphDebugEventId = typeof GraphDebugEventId; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_backend/canceller.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_backend/canceller.ts new file mode 100644 index 0000000000000000000000000000000000000000..9510edd5473f31d6cd9bc49d56107dd21e7a9ba3 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_backend/canceller.ts @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +export interface CancelResult { + value: T; + cancelled: boolean; +} + +/** + * A class that allows marking promises as cancelled. + * + * This can be useful to, e.g., prevent old network requests from + * stomping new ones and writing bad data. + * + * Usage: + * + * const canceller = new Canceller(); + * let myPromise: Promise = getPromise(); + * myPromise.then(canceller.cancellable(({value, cancelled} => { + * if (cancelled) { + * console.warn("Don't make promises you can't keep >:-{"); + * } + * console.log("Enjoy your value:", value); + * })); + * + * // If `myPromise` is resolved now, then `cancelled` will be `false`. + * canceller.cancelAll(); + * // If `myPromise` is resolved now, then `cancelled` will be `true`. + */ +export class Canceller { + /** + * How many times has `cancelAll` been called? + */ + private cancellationCount = 0; + /** + * Create a cancellable task. This returns a new function that, when + * invoked, will pass its argument to the provided function as well as + * a `cancelled` argument. This argument will be `false` unless and + * until `cancelAll` is invoked after the creation of this task. + */ + public cancellable(f: (result: CancelResult) => U): (T) => U { + const originalCancellationCount = this.cancellationCount; + return (value) => { + const cancelled = this.cancellationCount !== originalCancellationCount; + return f({ value, cancelled }); + }; + } + /** + * Mark all outstanding tasks as cancelled. Tasks not yet created will + * not be affected. + */ + public cancelAll(): void { + this.cancellationCount++; + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_backend/requestManager.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_backend/requestManager.ts new file mode 100644 index 0000000000000000000000000000000000000000..9b89a769645f6d75baf1bd20710c4f14df6a400b --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_backend/requestManager.ts @@ -0,0 +1,306 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +const FEATURE_FLAGS_HEADER_NAME = 'X-TensorBoard-Feature-Flags'; + +interface ResolveReject { + resolve: (x: unknown) => void; + reject: (x: unknown) => void; +} + +/** + * Manages many fetch requests. Launches up to nSimultaneousRequests + * simultaneously, and maintains a LIFO queue of requests to process when + * more urls are requested than can be handled at once. The queue can be + * cleared. + * + * When a request is made, a Promise is returned which resolves with the + * parsed JSON result from the request. + */ +export class RequestCancellationError extends Error { + public override name = 'RequestCancellationError'; +} + +export class InvalidRequestOptionsError extends Error { + public override name = 'InvalidRequestOptionsError'; + constructor(msg: string) { + super(msg); + // The following is needed due to a limitation of TypeScript when + // extending 'Error'. See: https://github.com/Microsoft/TypeScript/wiki/Breaking-Changes#extending-built-ins-like-error-array-and-map-may-no-longer-work + Object.setPrototypeOf(this, InvalidRequestOptionsError.prototype); + } +} + +export class RequestNetworkError extends Error { + public override name: string; + public req: XMLHttpRequest; + public url: string; + constructor(req: XMLHttpRequest, url) { + super(); + this.message = `RequestNetworkError: ${req.status} at ${url}`; + this.name = 'RequestNetworkError'; + this.req = req; + this.url = url; + } +} + +/** The HTTP method-type to use. Currently only 'GET' and 'POST' are + * supported. + */ +export enum HttpMethodType { + GET = 'GET', + POST = 'POST', +} + +/** + * Holds options that can be used to configure the HTTP request. + */ +export class RequestOptions { + public methodType: HttpMethodType; + /** The content-type request header to use. Cannot be set for a GET request.*/ + public contentType?: string; + /** The request body to use. This is the object that is passed to the + * XMLHttpRequest.send() method. If not given the 'send' method is called + * without an argument. + */ + public body?: any; + /** If specified, this will be the value set in the + * XMLHttpRequest.withCredentials property. + */ + public withCredentials?: boolean; + // Validates this object. Throws InvalidRequestOptionsError on error. + public validate() { + if (this.methodType === HttpMethodType.GET) { + // We don't allow a body for a GET. + if (this.body) { + throw new InvalidRequestOptionsError( + 'body must be missing for a GET request.' + ); + } + } + // We allow body-less or contentType-less POSTs even if they don't + // make much sense. + } +} + +// Form data for a POST request as a convenient multidict interface, +// since the built-in `FormData` type doesn't have a value constructor. +// +// A raw string value is equivalent to a singleton array, and thus an +// empty array value is equivalent to omitting the key entirely. +export interface PostData { + [key: string]: string | string[]; +} + +export class RequestManager { + private _queue: ResolveReject[]; + private _maxRetries: number; + private _nActiveRequests: number; + private _nSimultaneousRequests: number; + constructor(nSimultaneousRequests = 1000, maxRetries = 3) { + this._queue = []; + this._nActiveRequests = 0; + this._nSimultaneousRequests = nSimultaneousRequests; + this._maxRetries = maxRetries; + } + /** + * Gives a promise that loads assets from given url (respects queuing). If + * postData is provided, this request will use POST, not GET. This is an + * object mapping POST keys to string values. + */ + public request(url: string, postData?: PostData): Promise { + const requestOptions = requestOptionsFromPostData(postData); + return this.requestWithOptions(url, requestOptions); + } + public requestWithOptions( + url: string, + requestOptions: RequestOptions + ): Promise { + requestOptions.validate(); + const promise = new Promise((resolve, reject) => { + const resolver = { resolve: resolve, reject: reject }; + this._queue.push(resolver); + this.launchRequests(); + }) + .then(() => { + return this.promiseWithRetries(url, this._maxRetries, requestOptions); + }) + .then( + (response) => { + // Success - Let's free space for another active + // request, and launch it + this._nActiveRequests--; + this.launchRequests(); + return response; + }, + (rejection) => { + if (rejection.name === 'RequestNetworkError') { + // If we failed due to network error, we should + // decrement + // _nActiveRequests because this request was + // active + this._nActiveRequests--; + this.launchRequests(); + } + return Promise.reject(rejection); + } + ); + return promise; + } + public fetch(url: string, fetchOptions?: RequestInit): Promise { + return new Promise((resolve, reject) => { + const resolver = { resolve: resolve, reject: reject }; + this._queue.push(resolver); + this.launchRequests(); + }).then(() => { + let numTries = 1; + return new Promise((resolve) => { + const retryFetch = () => { + fetch(url, fetchOptions).then((response) => { + if (!response.ok && this._maxRetries > numTries) { + numTries++; + retryFetch(); + return; + } + resolve(response); + this._nActiveRequests--; + this.launchRequests(); + }); + }; + retryFetch(); + }); + }); + } + public clearQueue() { + while (this._queue.length > 0) { + this._queue + .pop() + ?.reject( + new RequestCancellationError('Request cancelled by clearQueue') + ); + } + } + + /* Return number of currently pending requests */ + public activeRequests(): number { + return this._nActiveRequests; + } + /* Return total number of outstanding requests (includes queue) */ + public outstandingRequests(): number { + return this._nActiveRequests + this._queue.length; + } + private launchRequests() { + while ( + this._nActiveRequests < this._nSimultaneousRequests && + this._queue.length > 0 + ) { + this._nActiveRequests++; + this._queue.pop()!.resolve(undefined); + } + } + /** + * Try to request a given URL using overwritable _promiseFromUrl method. + * If the request fails for any reason, we will retry up to maxRetries + * times. In practice, this will help us paper over transient network issues + * like '502 Bad Gateway'. + * By default, Chrome displays network errors in console, so + * the user will be able to tell when the requests are failing. I think this + * is a feature, if the request failures and retries are causing any + * pain to users, they can see it and file issues. + */ + private promiseWithRetries( + url: string, + maxRetries: number, + requestOptions: RequestOptions + ) { + var success = (x) => x; + var failure = (x) => { + if (maxRetries > 0) { + return this.promiseWithRetries(url, maxRetries - 1, requestOptions); + } else { + return Promise.reject(x); + } + }; + return this._promiseFromUrl(url, requestOptions).then(success, failure); + } + /* Actually get promise from url using XMLHttpRequest */ + protected _promiseFromUrl(url: string, requestOptions: RequestOptions) { + return new Promise((resolve, reject) => { + const req = buildXMLHttpRequest( + requestOptions.methodType, + url, + requestOptions.withCredentials, + requestOptions.contentType + ); + req.setRequestHeader( + FEATURE_FLAGS_HEADER_NAME, + JSON.stringify({}) + ); + req.onload = function () { + if (req.status === 200) { + resolve(JSON.parse(req.responseText) as any); + } else { + reject(new RequestNetworkError(req, url)); + } + }; + req.onerror = function () { + reject(new RequestNetworkError(req, url)); + }; + if (requestOptions.body) { + req.send(requestOptions.body); + } else { + req.send(); + } + }); + } +} + +function buildXMLHttpRequest( + methodType: HttpMethodType, + url: string, + withCredentials?: boolean, + contentType?: string +): XMLHttpRequest { + const req = new XMLHttpRequest(); + req.open(methodType, url); + if (withCredentials) { + req.withCredentials = withCredentials; + } + if (contentType) { + req.setRequestHeader('Content-Type', contentType); + } + return req; +} + +function requestOptionsFromPostData(postData?: PostData): RequestOptions { + const result = new RequestOptions(); + if (!postData) { + result.methodType = HttpMethodType.GET; + return result; + } + result.methodType = HttpMethodType.POST; + result.body = formDataFromDictionary(postData); + return result; +} + +function formDataFromDictionary(postData: PostData) { + const formData = new FormData(); + for (const [key, maybeValues] of Object.entries(postData)) { + const values = Array.isArray(maybeValues) ? maybeValues : [maybeValues]; + for (const value of values) { + formData.append(key, value); + } + } + return formData; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_backend/router.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_backend/router.ts new file mode 100644 index 0000000000000000000000000000000000000000..082939622db295cc0438c0f406a57d2907a38045 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_backend/router.ts @@ -0,0 +1,79 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +export type QueryParams = { + [key: string]: string | string[]; +}; + +const EXPERIMENTAL_PLUGINS_QUERY_PARAM = 'experimentalPlugin'; + +export interface Router { + pluginRouteForSrc: (pluginName: string, route: string, params?: URLSearchParams) => string; + pluginsListing: () => string; +} + +/** + * Save the initial URL query params, before the AppRoutingEffects initialize, + * and before creating the router. + */ +const initialURLSearchParams = new URLSearchParams(window.location.search); +let _router: Router = createRouter(); + +/** + * Create a router for communicating with the TensorBoard backend. You + * can pass this to `setRouter` to make it the global router. + */ +export function createRouter(): Router { + return { + pluginRouteForSrc: (pluginName: string, route: string, params: URLSearchParams = new URLSearchParams()): string => { + return createDataPath(`/plugin/${pluginName}${route}`, params); + }, + pluginsListing: () => + createDataPath( + '/plugins_listing', + createSearchParam({ + [EXPERIMENTAL_PLUGINS_QUERY_PARAM]: initialURLSearchParams.getAll(EXPERIMENTAL_PLUGINS_QUERY_PARAM), + }), + ), + }; +} + +/** + * @return {Router} the global router + */ +export function getRouter(): Router { + return _router; +} + +function createDataPath(route: string, params: URLSearchParams = new URLSearchParams()): string { + let relativePath = '/data' + route; + if (String(params)) { + const delimiter = route.includes('?') ? '&' : '?'; + relativePath += delimiter + String(params); + } + return relativePath; +} + +export function createSearchParam(params: QueryParams = {}): URLSearchParams { + const keys = Object.keys(params) + .sort() + .filter((k) => params[k]); + const searchParams = new URLSearchParams(); + keys.forEach((key) => { + const values = params[key]; + const array = Array.isArray(values) ? values : [values]; + array.forEach((val) => searchParams.append(key, val)); + }); + return searchParams; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_dashboard_common/scrollbar-style.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_dashboard_common/scrollbar-style.ts new file mode 100644 index 0000000000000000000000000000000000000000..1793c1285793c5123e2c043c6150661b625bf21f --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_dashboard_common/scrollbar-style.ts @@ -0,0 +1,39 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {registerStyleDomModule} from '../polymer/register_style_dom_module'; + +registerStyleDomModule({ + moduleName: 'scrollbar-style', + styleContent: ` + .scrollbar::-webkit-scrollbar-track { + visibility: hidden; + } + + .scrollbar::-webkit-scrollbar { + width: 10px; + } + + .scrollbar::-webkit-scrollbar-thumb { + border-radius: 10px; + -webkit-box-shadow: inset 0 0 2px rgba(0, 0, 0, 0.3); + background-color: var(--paper-grey-500); + color: var(--paper-grey-900); + } + .scrollbar { + box-sizing: border-box; + } + `, +}); diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_dashboard_common/tensorboard-color.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_dashboard_common/tensorboard-color.ts new file mode 100644 index 0000000000000000000000000000000000000000..60dbb0a43458df7035522a2627f463e44c301448 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_dashboard_common/tensorboard-color.ts @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +const style = document.createElement('style'); +style.setAttribute('is', 'custom-style'); +style.textContent = ` + :root { + --tb-orange-weak: #ffa726; + --tb-orange-strong: #f57c00; + --tb-orange-dark: #dc7320; + --tb-grey-darker: #e2e2e2; + --tb-grey-lighter: #f3f3f3; + --tb-ui-dark-accent: #757575; + --tb-ui-light-accent: #e0e0e0; + --tb-ui-border: var(--paper-grey-300); + --tb-graph-faded: #e0d4b3; + --tb-secondary-text-color: var(--paper-grey-800); + --tb-raised-button-shadow-color: rgba(0, 0, 0, 0.2); + --primary-background-color: #fff; + --secondary-background-color: #e9e9e9; + --tb-layout-background-color: #f5f5f5; + --tb-link: #1976d2; /* material blue 700. */ + --tb-link-visited: #7b1fa2; /* material purple 700. */ + } + + :root .dark-mode { + --tb-ui-border: var(--paper-grey-700); + --tb-ui-dark-accent: var(--paper-grey-400); + --tb-ui-light-accent: var(--paper-grey-600); + --tb-secondary-text-color: var(--paper-grey-400); + --tb-raised-button-shadow-color: rgba(255, 255, 255, 0.5); + --primary-text-color: #fff; + --secondary-text-color: var(--paper-grey-400); + --primary-background-color: #303030; /* material grey A400. */ + --secondary-background-color: #3a3a3a; + --tb-layout-background-color: #3a3a3a; + --tb-link: #42a5f5; /* material blue 400. */ + --tb-link-visited: #ba68c8; /* material purple 300. */ + /* Overrides paper-material */ + --shadow-elevation-2dp_-_box-shadow: 0 2px 2px 0 rgba(255, 255, 255, 0.14), + 0 1px 5px 0 rgba(255, 255, 255, 0.12), + 0 3px 1px -2px rgba(255, 255, 255, 0.2); + } +`; +document.head.appendChild(style); diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_dashboard_common/tf-dashboard-layout.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_dashboard_common/tf-dashboard-layout.ts new file mode 100644 index 0000000000000000000000000000000000000000..51b5491c781f8d7c523f682c9421b247524db38a --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_dashboard_common/tf-dashboard-layout.ts @@ -0,0 +1,75 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {customElement} from '@polymer/decorators'; +import {html, PolymerElement} from '@polymer/polymer'; +import {DarkModeMixin} from '../polymer/dark_mode_mixin'; +import './scrollbar-style'; +import './tensorboard-color'; + +@customElement('tf-dashboard-layout') +class TfDashboardLayout extends DarkModeMixin(PolymerElement) { + static readonly template = html` + + +
+ +
+ + + `; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_globals/globals.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_globals/globals.ts new file mode 100644 index 0000000000000000000000000000000000000000..baee926c653d484e311facdbb84ce2bc78e0e6d1 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_globals/globals.ts @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// If true, TensorBoard stores its hash in the URI state. +// If false, tab switching in TensorBoard will not update location hash, +// because hash updates interfere with wct_tests. +let _useHash = false; + +export function setUseHash(shouldUseHash: boolean): void { + _useHash = shouldUseHash; +} + +export function useHash(): boolean { + return _useHash; +} + +let _fakeHash = ''; + +export function setFakeHash(h: string) { + _fakeHash = h; +} + +export function getFakeHash() { + return _fakeHash; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-minimap.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-minimap.ts new file mode 100644 index 0000000000000000000000000000000000000000..5725b417c0e429a6314054eaed4f4015030101fe --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-minimap.ts @@ -0,0 +1,103 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {customElement} from '@polymer/decorators'; +import {html, PolymerElement} from '@polymer/polymer'; +import * as tf_scene_minimap from '../tf_graph_common/minimap'; + +@customElement('tf-graph-minimap') +export class TfGraphMinimap extends PolymerElement { + static readonly template = html` + + + + + + + + + + + + + + + + + `; + /** + * Initializes the minimap and returns a minimap object to notify when + * things update. + * + * @param svg The main svg element. + * @param zoomG The svg group used for panning and zooming the main svg. + * @param mainZoom The main zoom behavior. + * @param maxWAndH The maximum width/height for the minimap. + * @param labelPadding Padding in pixels due to the main graph labels. + */ + init(svg, zoomG, mainZoom, maxWAndH, labelPadding) { + return new tf_scene_minimap.Minimap( + svg, + zoomG, + mainZoom, + this, + maxWAndH, + labelPadding + ); + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-scene.html.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-scene.html.ts new file mode 100644 index 0000000000000000000000000000000000000000..b8250e300a67b5e5abf409fb54a0a7315024699e --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-scene.html.ts @@ -0,0 +1,769 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import { html } from '@polymer/polymer'; + +// Please keep node font-size/classnames in sync with tf-graph-common/common.ts +export const template = html` + +
+
Main Graph
+
Auxiliary Nodes
+
Functions
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+`; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-scene.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-scene.ts new file mode 100644 index 0000000000000000000000000000000000000000..f140c2ad52d9c7b3a888ead7044ee062b1288ab7 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph-scene.ts @@ -0,0 +1,740 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { customElement, observe, property } from '@polymer/decorators'; +import { PolymerElement } from '@polymer/polymer'; +import * as d3 from 'd3'; +import * as _ from 'lodash'; +import { DarkModeMixin } from '../polymer/dark_mode_mixin'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import * as tb_debug from '../tb_debug'; +import '../tf_dashboard_common/tensorboard-color'; +import * as tf_graph from '../tf_graph_common/graph'; +import * as tf_graph_layout from '../tf_graph_common/layout'; +import * as tf_graph_minimap from '../tf_graph_common/minimap'; +import * as tf_graph_scene_node from '../tf_graph_common/node'; +import * as tf_graph_render from '../tf_graph_common/render'; +import * as tf_graph_scene from '../tf_graph_common/scene'; +import { TfGraphScene } from '../tf_graph_common/tf-graph-scene'; +import * as tf_graph_util from '../tf_graph_common/util'; +import './tf-graph-minimap'; +import { template } from './tf-graph-scene.html'; + +@customElement('tf-graph-scene') +class TfGraphScene2 extends LegacyElementMixin(DarkModeMixin(PolymerElement)) implements TfGraphScene { + static readonly template = template; + @property({ type: Number }) + _step: number = 20; + @property({ type: Number }) + _scaleStep: number = 1; + @property({ type: Number }) + mouseX: number = 0; + @property({ type: Number }) + mouseY: number = 0; + @property({ type: Number }) + x: number = 0; + @property({ type: Number }) + y: number = 0; + @property({ type: Object }) + renderHierarchy: tf_graph_render.RenderGraphInfo; + @property({ type: String }) + name: string; + @property({ type: Boolean }) + traceInputs: boolean; + + // For each render hierarchy, we only fit it to the viewport once (when the scene is attached to + // the DOM). We do not fit the hierarchy again (unless the user clicks the reset button). For + // instance, if the user enters a certain view in the graph, switches to another dashboard, and + // returns to the graph dashboard, the user expects the previous view. These properties enable + // that behavior. + + /** Whether the scene has fit the current render hierarchy (to the viewport) at least once. */ + @property({ type: Boolean }) + _hasRenderHierarchyBeenFitOnce: boolean; + + /** Whether this scene element is currently attached to a parent element. */ + @property({ type: Boolean }) + _isAttached: boolean; + + /** This property is a d3_zoom object. */ + @property({ type: Object }) + _zoom: object; + + /** This property is a d3_drag object. */ + @property({ type: Object }) + _drag: object; + + @property({ + type: String, + observer: '_highlightedNodeChanged', + }) + highlightedNode: string; + @property({ + type: String, + observer: '_selectedNodeChanged', + }) + selectedNode: string; + @property({ + type: String, + observer: '_linkedNodeChanged', + }) + linkedNode: string; + + // An optional callback that implements the tf.graph.edge.EdgeSelectionCallback signature. If + // provided, edges are selectable, and this callback is run when an edge is selected. + @property({ type: Object }) + handleEdgeSelected: object; + + /** Keeps track of if the graph has been zoomed/panned since loading */ + @property({ + type: Boolean, + observer: '_onZoomChanged', + }) + _zoomed: boolean = false; + + /** + * Keeps track of the starting coordinates of a graph zoom/pan. + * + * @private {{x: number, y: number}?} + */ + @property({ + type: Object, + }) + _zoomStartCoords: object | null = null; + + /** + * Keeps track of the current coordinates of a graph zoom/pan + * + * @private {{x: number, y: number}?} + */ + @property({ + type: Object, + }) + _zoomTransform: object | null = null; + + /** Maximum distance of a zoom event for it to be interpreted as a click */ + @property({ + type: Number, + }) + _maxZoomDistanceForClick: number = 20; + + /** + * Scale mapping from template name to a number between 0 and N-1 + * where N is the number of different template names. Used by + * tf_graph_scene_node when computing node color by structure. + * This property is a d3.scale.ordinal object. + */ + @property({ type: Object }) templateIndex: (name: string) => number | null; + + /** + * A minimap object to notify for zoom events. + */ + private minimap: tf_graph_minimap.Minimap; + + /* + * Dictionary for easily stylizing nodes when state changes. + * _nodeGroupIndex[nodeName] = d3_selection of the nodeGroup + */ + @property({ + type: Object, + }) + _nodeGroupIndex = {}; + + /* + * Dictionary for easily stylizing annotation nodes when state changes. + * _annotationGroupIndex[nodeName][hostNodeName] = + * d3_selection of the annotationGroup + */ + @property({ + type: Object, + }) + _annotationGroupIndex = {}; + + /* + * Dictionary for easily stylizing edges when state changes. + * _edgeGroupIndex[edgeName] = d3_selection of the edgeGroup + */ + @property({ + type: Object, + }) + _edgeGroupIndex = {}; + + /** + * Max font size for metanode label strings. + */ + @property({ + type: Number, + }) + maxMetanodeLabelLengthFontSize: number = 9; + + /** + * Min font size for metanode label strings. + */ + @property({ + type: Number, + }) + minMetanodeLabelLengthFontSize: number = 6; + + /** + * Metanode label strings longer than this are given smaller fonts. + */ + @property({ + type: Number, + }) + maxMetanodeLabelLengthLargeFont: number = 11; + + /** + * Metanode label strings longer than this are truncated with ellipses. + */ + @property({ + type: Number, + }) + maxMetanodeLabelLength: number = 50; + @property({ type: Object }) + progress: any; + + // An array of ContextMenuItem objects. Items that appear in the context + // menu for a node. + @property({ type: Array }) + nodeContextMenuItems: unknown[]; + getNode(nodeName) { + return this.renderHierarchy.getRenderNodeByName(nodeName); + } + isNodeExpanded(node) { + return node.expanded; + } + setNodeExpanded(renderNode) { + this._build(this.renderHierarchy); + this._updateLabels(!this._zoomed); + } + /** + * Pans to a node. Assumes that the node exists. + * @param nodeName {string} The name of the node to pan to. + */ + panToNode(nodeName) { + const zoomed = tf_graph_scene.panToNode(nodeName, this.$.svg, this.$.root, this._zoom); + if (zoomed) { + this._zoomed = true; + } + } + /** + * Returns the outer-most SVG that renders the graph. + */ + getGraphSvgRoot(): SVGElement { + return this.$.svg as SVGElement; + } + getContextMenu(): HTMLElement { + return this.$.contextMenu as HTMLElement; + } + /** + * Resets the state of the component. Called whenever the whole graph + * (dataset) changes. + */ + _resetState() { + // Reset the state of the component. + this._nodeGroupIndex = {}; + this._annotationGroupIndex = {}; + this._edgeGroupIndex = {}; + this._updateLabels(false); + // Remove all svg elements under the 'root' svg group. + d3.select(this.$.svg).select('#root').selectAll('*').remove(); + // And the defs. + tf_graph_scene_node.removeGradientDefinitions(this.$.svg as SVGElement); + } + /** Main method for building the scene */ + _build(renderHierarchy: tf_graph_render.RenderGraphInfo) { + if (!renderHierarchy) { + return; + } + this.templateIndex = renderHierarchy.hierarchy.getTemplateIndex(); + tf_graph_util.time( + 'tf-graph-scene (layout):', + function () { + // layout the scene for this meta / series node + tf_graph_layout.layoutScene(renderHierarchy.root); + }.bind(this), + tb_debug.GraphDebugEventId.RENDER_SCENE_LAYOUT, + ); + tf_graph_util.time( + 'tf-graph-scene (build scene):', + function () { + tf_graph_scene_node.buildGroupForScene(d3.select(this.$.root), renderHierarchy.root, this); + tf_graph_scene.addGraphClickListener(this.$.svg, this); + this._updateInputTrace(); + }.bind(this), + tb_debug.GraphDebugEventId.RENDER_SCENE_BUILD_SCENE, + ); + // Update the minimap again when the graph is done animating. + setTimeout( + function () { + this.minimap.update(); + }.bind(this), + tf_graph_layout.PARAMS.animation.duration, + ); + } + ready() { + super.ready(); + this._zoom = d3 + .zoom() + .on( + 'end', + function () { + if (this._zoomStartCoords) { + // Calculate the total distance dragged during the zoom event. + // If it is sufficiently small, then fire an event indicating + // that zooming has ended. Otherwise wait to fire the zoom end + // event, so that a mouse click registered as part of this zooming + // is ignored (as this mouse click was part of a zooming, and should + // not be used to indicate an actual click on the graph). + var dragDistance = Math.sqrt( + Math.pow(this._zoomStartCoords.x - this._zoomTransform.x, 2) + + Math.pow(this._zoomStartCoords.y - this._zoomTransform.y, 2), + ); + if (dragDistance < this._maxZoomDistanceForClick) { + this._fireEnableClick(); + } else { + setTimeout(this._fireEnableClick.bind(this), 50); + } + } + this._zoomStartCoords = null; + }.bind(this), + ) + .on( + 'zoom', + function () { + this._zoomTransform = d3.event.transform; + if (!this._zoomStartCoords) { + this._zoomStartCoords = this._zoomTransform; + this.fire('disable-click'); + } + this._zoomed = true; + d3.select(this.$.root).attr('transform', d3.event.transform.toString()); + this.x = d3.event.transform.x; + this.y = d3.event.transform.y; + this.k = d3.event.transform.k; + // Notify the minimap. + this.minimap.zoom(d3.event.transform); + }.bind(this), + ); + + d3.select(this.$.svg).call(this._addEventListener.bind(this)).on('dblclick.zoom', null); + d3.select(window).on( + 'resize', + function () { + // Notify the minimap that the user's window was resized. + // The minimap will figure out the new dimensions of the main svg + // and will use the existing translate and scale params. + this.minimap.zoom(); + }.bind(this), + ); + // Initialize the minimap. + this.minimap = (this.$.minimap as any).init( + this.$.svg, + this.$.root, + this._zoom, + tf_graph_layout.PARAMS.minimap.size, + tf_graph_layout.PARAMS.subscene.meta.labelHeight, + ); + + // Add keyboard event listener + this._addEventListener(); + } + + _addEventListener() { + let isDragging = false; + let startX, startY; + let lastTime = 0; + const smoothFactor = 0.2; // 控制平滑的因子 + const maxDelta = 800; // 限制滚动速度 + const svgElement = this.$.svg as SVGSVGElement; + svgElement.setAttribute('tabindex', '0'); + + svgElement.addEventListener('mousedown', (event: MouseEvent) => { + isDragging = true; + startX = event.clientX; + startY = event.clientY; + svgElement.focus(); + + svgElement.addEventListener('mousemove', (event: MouseEvent) => { + if (isDragging) { + this.x = Math.min(Math.max(this.x + (event.clientX - startX) / 2, -10000), 10000); + this.y = Math.min(Math.max(this.y + (event.clientY - startY) / 2, -10000), 10000); + this._moveView(); + startX = event.clientX; + startY = event.clientY; + } + }); + }); + window.addEventListener('mouseup', () => { + isDragging = false; + }); + svgElement.addEventListener('wheel', (event: WheelEvent) => { + const currentTime = performance.now(); + const deltaTime = currentTime - lastTime; + if (deltaTime > 16) { + // 确保每帧调用 + const deltaY = Math.sign(event.deltaY) * Math.min(Math.abs(event.deltaY), maxDelta); + this.y = Math.min(Math.max(this.y - deltaY * smoothFactor * 2, -10000), 10000); + this._moveView(); + lastTime = currentTime; + } + }); + svgElement.addEventListener('keydown', (event: KeyboardEvent) => { + switch (event.key) { + case 'w': + case 'W': + this._scaleStep = 1.1; + this._scaleView(this._scaleStep); + break; + case 's': + case 'S': + this._scaleStep = 0.9; + this._scaleView(this._scaleStep); + break; + case 'a': + case 'A': + this.x += this._step; + this._moveView(); + break; + case 'd': + case 'D': + this.x -= this._step; + this._moveView(); + break; + default: + return; // Exit if it's not an arrow key + } + }); + } + + _scaleView(scaleFactor: number) { + if (this._zoomTransform) { + const svgElement = this.$.svg as SVGSVGElement; + const currentTransform = d3.zoomTransform(svgElement); + const k = currentTransform.k === 0 ? 1 : currentTransform.k; + const [mouseX, mouseY] = [ + this.mouseX - svgElement.getBoundingClientRect().left, + this.mouseY - svgElement.getBoundingClientRect().top, + ]; + const translateX = (mouseX - currentTransform.x) / k; + const translateY = (mouseY - currentTransform.y) / k; + const newScale = currentTransform.k * scaleFactor; + this.x = mouseX - translateX * newScale; + this.y = mouseY - translateY * newScale; + const newTransform = d3.zoomIdentity.translate(this.x, this.y).scale(newScale); + d3.select(this.$.svg).call(d3.zoom().transform, newTransform); + d3.select(this.$.root).attr('transform', newTransform.toString()); + this._zoomTransform = newTransform; + this.minimap.zoom(newTransform); + } + } + + _moveView() { + if (this._zoomTransform) { + const svgElement = this.$.svg as SVGElement; + const currentTransform = d3.zoomTransform(svgElement); + const newTransform = d3.zoomIdentity.translate(this.x, this.y).scale(currentTransform.k); + d3.select(this.$.svg).call(d3.zoom().transform, newTransform); + d3.select(this.$.root).attr('transform', newTransform.toString()); + this._zoomTransform = newTransform; + // 通知小地图变更 + this.minimap.zoom(newTransform); + } + } + + override attached() { + this.set('_isAttached', true); + } + override detached() { + this.set('_isAttached', false); + } + @observe('renderHierarchy') + _renderHierarchyChanged() { + var renderHierarchy = this.renderHierarchy; + this._hasRenderHierarchyBeenFitOnce = false; + this._resetState(); + this._build(renderHierarchy); + } + + // Animation and fitting must come after the observer for the hierarchy changing because we must + // first build the render hierarchy. + @observe('_isAttached', 'renderHierarchy') + _animateAndFit() { + var isAttached = this._isAttached; + if (this._hasRenderHierarchyBeenFitOnce || !isAttached) { + // Do not animate and fit if the scene has already fitted this render hierarchy once. Or if + // the graph dashboard is not attached (in which case the scene lacks DOM info for fitting). + return; + } + // Fit to screen after the graph is done animating. + setTimeout(this.fit.bind(this), tf_graph_layout.PARAMS.animation.duration); + } + _updateLabels(showLabels) { + var mainGraphTitleElement = this.$$('.title') as HTMLElement; + var titleStyle = mainGraphTitleElement.style; + var auxTitleElement = this.$$('.auxTitle') as HTMLElement; + var auxTitleStyle = auxTitleElement.style; + var functionLibraryTitleStyle = (this.$$('.functionLibraryTitle') as HTMLElement).style; + const root = d3.select(this.$.svg); + var core = root.select('.' + tf_graph_scene.Class.Scene.GROUP + '>.' + tf_graph_scene.Class.Scene.CORE).node(); + // Only show labels if the graph is fully loaded. + if (showLabels && core && this.progress && this.progress.value === 100) { + var aux = + root.select('.' + tf_graph_scene.Class.Scene.GROUP + '>.' + tf_graph_scene.Class.Scene.INEXTRACT).node() || + root.select('.' + tf_graph_scene.Class.Scene.GROUP + '>.' + tf_graph_scene.Class.Scene.OUTEXTRACT).node(); + var coreX = (core as any).getCTM().e; + var auxX = aux ? (aux as any).getCTM().e : null; + titleStyle.display = 'inline'; + titleStyle.left = coreX + 'px'; + if (auxX !== null && auxX !== coreX) { + auxTitleStyle.display = 'inline'; + // Make sure that the aux title is positioned rightwards enough so as to + // prevent overlap with the main graph title. + auxX = Math.max(coreX + mainGraphTitleElement.getBoundingClientRect().width, auxX); + auxTitleStyle.left = auxX + 'px'; + } else { + auxTitleStyle.display = 'none'; + } + let functionLibrary = root + .select('.' + tf_graph_scene.Class.Scene.GROUP + '>.' + tf_graph_scene.Class.Scene.FUNCTION_LIBRARY) + .node(); + let functionLibraryX = functionLibrary ? (functionLibrary as any).getCTM().e : null; + if (functionLibraryX !== null && functionLibraryX !== auxX) { + functionLibraryTitleStyle.display = 'inline'; + // Make sure that the function library title is positioned rightwards + // enough so as to prevent overlap with other content. + functionLibraryX = Math.max(auxX + auxTitleElement.getBoundingClientRect().width, functionLibraryX); + functionLibraryTitleStyle.left = functionLibraryX + 'px'; + } else { + functionLibraryTitleStyle.display = 'none'; + } + } else { + titleStyle.display = 'none'; + auxTitleStyle.display = 'none'; + functionLibraryTitleStyle.display = 'none'; + } + } + fit() { + this._hasRenderHierarchyBeenFitOnce = true; + this._scaleStep = 1; + tf_graph_scene.fit( + this.$.svg, + this.$.root, + this._zoom, + function () { + this._zoomed = false; + }.bind(this), + ); + } + getImageBlob(): Promise { + return this.minimap.getImageBlob(); + } + isNodeSelected(n) { + return n === this.selectedNode; + } + isNodeHighlighted(n) { + return n === this.highlightedNode; + } + isNodeLinked(n) { + return n === this.linkedNode; + } + addAnnotationGroup(a, d, selection) { + var an = a.node.name; + this._annotationGroupIndex[an] = this._annotationGroupIndex[an] || {}; + this._annotationGroupIndex[an][d.node.name] = selection; + } + getAnnotationGroupsIndex(a) { + return this._annotationGroupIndex[a]; + } + removeAnnotationGroup(a, d) { + delete this._annotationGroupIndex[a.node.name][d.node.name]; + } + addNodeGroup(n, selection) { + this._nodeGroupIndex[n] = selection; + } + getNodeGroup(n) { + return this._nodeGroupIndex[n]; + } + removeNodeGroup(n) { + delete this._nodeGroupIndex[n]; + } + addEdgeGroup(n, selection) { + this._edgeGroupIndex[n] = selection; + } + getEdgeGroup(e) { + return this._edgeGroupIndex[e]; + } + /** + * Update node and annotation node of the given name. + * @param {String} n node name + */ + _updateNodeState(n) { + var node = this.getNode(n); + if (!node) { + return; + } + var nodeGroup = this.getNodeGroup(n); + if (nodeGroup) { + tf_graph_scene_node.stylize(nodeGroup, node, this as any); + } + if ( + node.node.type === (tf_graph.NodeType.META || tf_graph.NodeType.API_LIST) && + (node.node as any).associatedFunction && + !node.isLibraryFunction + ) { + // The node is that of a function call. Also link the node within the + // function library. This clarifies to the user that the library function + // is being used. + var libraryFunctionNodeName = tf_graph.FUNCTION_LIBRARY_NODE_PREFIX + (node.node as any).associatedFunction; + var functionGroup = d3.select( + '.' + + tf_graph_scene.Class.Scene.GROUP + + '>.' + + tf_graph_scene.Class.Scene.FUNCTION_LIBRARY + + ' g[data-name="' + + libraryFunctionNodeName + + '"]', + ); + tf_graph_scene_node.stylize(functionGroup, node, this as any); + } + var annotationGroupIndex = this.getAnnotationGroupsIndex(n); + _.each(annotationGroupIndex, (aGroup, hostName) => { + tf_graph_scene_node.stylize(aGroup, node, this as any, tf_graph_scene.Class.Annotation.NODE); + }); + } + /** + * Handles new node selection. 1) Updates the selected-state of each node, + * 2) triggers input tracing. + * @param selectedNode {string} The name of the newly selected node. + * @param oldSelectedNode {string} The name of the previously selected node. + * @private + */ + _selectedNodeChanged(selectedNode, oldSelectedNode) { + if (selectedNode === oldSelectedNode) { + return; + } + if (oldSelectedNode) { + this._updateNodeState(oldSelectedNode); + } + if (!selectedNode) { + this.linkedNode = ''; + return; + } + var node = this.renderHierarchy.hierarchy.node(selectedNode); + if (!node) { + return; + } + const linkNodes = node.nodeAttributes['_linked_node']; + if (Array.isArray(linkNodes)) { + let tempNode = ''; + let lastRenderNode: tf_graph_render.RenderNodeInfo | undefined = undefined; + let lastExpandStatus = false; + for (let linkNode of linkNodes) { + const renderLinkedNode = this.renderHierarchy.getRenderNodeByName(linkNode); + // Expand all ancestors of the linked node. + if (renderLinkedNode) { + lastRenderNode = renderLinkedNode; + lastExpandStatus = renderLinkedNode.expanded; + renderLinkedNode.expanded = true; + tempNode = linkNode; + } else { + break; + } + } + if (lastRenderNode) { + lastRenderNode.expanded = lastExpandStatus; + } + this.linkedNode = tempNode; + } else { + this.linkedNode = ''; + } + // Update the minimap to reflect the highlighted (selected) node. + (this.minimap as any).update(); + var nodeParents: string[] = []; + // Create list of all metanode parents of the selected node. + while (node.parentNode != null && node.parentNode.name != tf_graph.ROOT_NAME) { + node = (node as any).parentNode; + nodeParents.push(node.name); + } + // Ensure each parent metanode is built and expanded. + var topParentNodeToBeExpanded; + _.forEachRight(nodeParents, (parentName) => { + this.renderHierarchy.buildSubhierarchy(parentName); + var renderNode = this.renderHierarchy.getRenderNodeByName(parentName); + if (renderNode.node.isGroupNode && !renderNode.expanded) { + renderNode.expanded = true; + if (!topParentNodeToBeExpanded) { + topParentNodeToBeExpanded = renderNode; + } + } + }); + // If any expansion was needed to display this selected node, then + // inform the scene of the top-most expansion. + if (topParentNodeToBeExpanded) { + this.setNodeExpanded(topParentNodeToBeExpanded); + this._zoomed = true; + } + if (selectedNode) { + this._updateNodeState(selectedNode); + } + // Give time for any expanding to finish before panning to a node. + // Otherwise, the pan will be computed from incorrect measurements. + setTimeout(() => { + this.panToNode(selectedNode); + }, tf_graph_layout.PARAMS.animation.duration); + } + _highlightedNodeChanged(highlightedNode, oldHighlightedNode) { + if (highlightedNode === oldHighlightedNode) { + return; + } + if (highlightedNode) { + this._updateNodeState(highlightedNode); + } + if (oldHighlightedNode) { + this._updateNodeState(oldHighlightedNode); + } + } + _linkedNodeChanged(linkedNode, oldLinkedNode) { + if (linkedNode === oldLinkedNode) { + return; + } + if (oldLinkedNode) { + this._updateNodeState(oldLinkedNode); + } + if (linkedNode) { + this._updateNodeState(linkedNode); + } + this._build(this.renderHierarchy); + } + _onZoomChanged() { + this._updateLabels(!this._zoomed); + } + _fireEnableClick() { + this.fire('enable-click'); + } + + // When renderHierarchy changes, we need to first build the new SVG based + // on the new hierarchy (and it is asynchronous). We will let that observer + // update the input trace. + @observe('traceInputs', 'selectedNode') + _updateInputTrace() { + tf_graph_scene_node.updateInputTrace( + this.getGraphSvgRoot(), + this.renderHierarchy, + this.selectedNode, + this.traceInputs, + ); + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph.ts new file mode 100644 index 0000000000000000000000000000000000000000..97bebdb7dc72b8ad8dcad2691f596351c5dfdf52 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph/tf-graph.ts @@ -0,0 +1,625 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { customElement, observe, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import * as d3 from 'd3'; +import * as _ from 'lodash'; +import '../polymer/irons_and_papers'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import * as tb_debug from '../tb_debug'; +import * as tf_graph from '../tf_graph_common/graph'; +import * as tf_graph_hierarchy from '../tf_graph_common/hierarchy'; +import * as tf_graph_render from '../tf_graph_common/render'; +import * as tf_graph_scene from '../tf_graph_common/scene'; +import * as tf_graph_util from '../tf_graph_common/util'; +import './tf-graph-scene'; +import { Selection } from '../tf_graph_controls/tf-graph-controls'; +import { getRouter } from '../tf_backend/router'; +import { fetchPbTxt, parseGraphPbTxt } from '../tf_graph_common/parser'; +import * as tf_hierarchy from '../tf_graph_common/hierarchy'; +import * as tf_graph_parser from '../tf_graph_common/parser'; + +@customElement('tf-graph') +class TfGraph extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + +
+
+ + +
+
+ `; + @property({ + type: Object, + notify: true, + observer: '_graphChanged', + }) + graphHierarchy: tf_graph_hierarchy.Hierarchy; + @property({ type: Object }) + basicGraph: tf_graph.SlimGraph; + @property({ type: Object }) + stats: object; + @property({ type: Object }) + devicesForStats: object; + @property({ type: Object }) + hierarchyParams: tf_graph_hierarchy.HierarchyParams; + @property({ + type: Object, + notify: true, + }) + progress: object; + @property({ type: String }) + override title: string; + @property({ + type: String, + notify: true, + }) + selectedNode: string; + @property({ + type: Object, + notify: true, + }) + selectedEdge: object; + @property({ type: Object }) + _lastSelectedEdgeGroup: any; + @property({ type: Object }) + _lastHighlightedEdgeGroup: any; + @property({ + type: String, + notify: true, + }) + highlightedNode: string; + @property({ + type: Object, + notify: true, + }) + highlightedEdge: tf_graph_render.EdgeData; + @property({ + type: Object, + readOnly: true, + notify: true, + }) + renderHierarchy: tf_graph_render.RenderGraphInfo; + @property({ type: Boolean }) + traceInputs: boolean; + @property({ type: Array }) + nodeContextMenuItems: unknown[]; + @property({ + type: Number, + }) + _renderDepth: number = 1; + @property({ + type: Boolean, + }) + _allowGraphSelect: boolean = true; + @property({ + type: Object, + }) + edgeWidthFunction: any = ''; + @property({ + type: Object, + }) + handleNodeSelected: any = ''; + @property({ + type: Object, + }) + edgeLabelFunction: any = ''; + @property({ + type: Object, + }) + handleEdgeSelected: any = ''; + @property({ + type: Object, + }) + selection: Selection; + /** + * Pans to a node. Assumes that the node exists. + * @param nodeName {string} The name of the node to pan to. + */ + panToNode(nodeName) { + (this.$$('tf-graph-scene') as any).panToNode(nodeName); + } + @observe('graphHierarchy', 'edgeWidthFunction', 'handleNodeSelected', 'edgeLabelFunction', 'handleEdgeSelected') + _buildNewRenderHierarchy() { + var graphHierarchy = this.graphHierarchy; + if (!graphHierarchy) return; + this._buildRenderHierarchy(graphHierarchy); + } + @observe('stats', 'devicesForStats') + _statsChanged() { + var stats = this.stats; + var devicesForStats = this.devicesForStats; + if (this.graphHierarchy) { + if (stats && devicesForStats) { + tf_graph.joinStatsInfoWithGraph(this.basicGraph, stats as any, devicesForStats as any); + tf_graph_hierarchy.joinAndAggregateStats(this.graphHierarchy, stats as any); + } + // Recompute the rendering information. + this._buildRenderHierarchy(this.graphHierarchy); + } + } + ready() { + super.ready(); + + this.addEventListener('graph-select', this._graphSelected.bind(this)); + this.addEventListener('disable-click', this._disableClick.bind(this)); + this.addEventListener('enable-click', this._enableClick.bind(this)); + // Nodes + this.addEventListener('node-toggle-expand', this._nodeToggleExpand.bind(this)); + document.addEventListener('menu-expand-node-changed', this._handleMenuExpandNodeChanged.bind(this)); + document.addEventListener('parent-node-toggle-expand', this._parentNodeToggleExpand.bind(this)); + this.addEventListener('node-select', this._nodeSelected.bind(this)); + this.addEventListener('node-highlight', this._nodeHighlighted.bind(this)); + this.addEventListener('node-unhighlight', this._nodeUnhighlighted.bind(this)); + this.addEventListener('node-toggle-extract', this._nodeToggleExtract.bind(this)); + this.addEventListener('node-toggle-seriesgroup', this._nodeToggleSeriesGroup.bind(this)); + // Edges + this.addEventListener('edge-select', this._edgeSelected.bind(this)); + this.addEventListener('edge-highlight', this._edgeHighlighted.bind(this)); + this.addEventListener('edge-unhighlight', this._edgeUnhighlighted.bind(this)); + + // Annotations + + /* Note: currently highlighting/selecting annotation node has the same + * behavior as highlighting/selecting actual node so we point to the same + * set of event listeners. However, we might redesign this to be a bit + * different. + */ + this.addEventListener('annotation-select', this._nodeSelected.bind(this)); + this.addEventListener('annotation-highlight', this._nodeHighlighted.bind(this)); + this.addEventListener('annotation-unhighlight', this._nodeUnhighlighted.bind(this)); + } + _buildRenderHierarchy(graphHierarchy) { + if (graphHierarchy.root.type !== tf_graph.NodeType.META) { + // root must be metanode but sometimes Polymer's dom-if has not + // remove tf-graph element yet in + // and thus mistakenly pass non-metanode to this module. + return; + } + + // Certain Polymer property setter are dynamically generated and is not properly + // typed. + const anyThis = this as any; + + const renderGraph = tf_graph_util.time( + 'new tf_graph_render.Hierarchy', + () => { + const renderGraph = new tf_graph_render.RenderGraphInfo(graphHierarchy, !!this.stats /** displayingStats */); + renderGraph.edgeLabelFunction = this.edgeLabelFunction; + renderGraph.edgeWidthFunction = this.edgeWidthFunction; + // Producing the 'color by' parameters to be consumed + // by the tf-graph-controls panel. It contains information about the + // min and max values and their respective colors, as well as list + // of devices with their respective colors. + function getColorParamsFromScale(scale) { + return { + minValue: scale.domain()[0], + maxValue: scale.domain()[1], + startColor: scale.range()[0], + endColor: scale.range()[1], + }; + } + + return renderGraph; + }, + tb_debug.GraphDebugEventId.RENDER_BUILD_HIERARCHY, + ); + + anyThis._setRenderHierarchy(renderGraph); + } + _getVisible(name) { + if (!name) { + return name; + } + return this.renderHierarchy.getNearestVisibleAncestor(name); + } + fit() { + (this.$.scene as any).fit(); + } + getImageBlob(): Promise { + return (this.$.scene as any).getImageBlob(); + } + _graphChanged() { + if (!this.graphHierarchy) { + return; + } + + this.graphHierarchy.addListener(tf_graph_hierarchy.HierarchyEvent.TEMPLATES_UPDATED, () => { + (this.$.scene as any).nodeColorsChanged(); + }); + + // When a new graph is loaded, fire this event so that there is no + // info-card being displayed for the previously-loaded graph. + this.fire('graph-select'); + } + _graphSelected(event) { + // Graph selection is not allowed during an active zoom event, as the + // click seen during a zoom/pan is part of the zooming and does not + // indicate a user desire to click on a specific section of the graph. + if (this._allowGraphSelect) { + // this.set('selectedNode', null); + this.set('selectedEdge', null); + } + // Reset this variable as a bug in d3 zoom behavior can cause zoomend + // callback not to be called if a right-click happens during a zoom event. + this._allowGraphSelect = true; + } + _disableClick(event) { + this._allowGraphSelect = false; + } + _enableClick(event) { + this._allowGraphSelect = true; + } + @observe('selectedNode') + // Called when the selected node changes, ie there is a new selected node or + // the current one is unselected. + _selectedNodeChanged() { + var selectedNode = this.selectedNode; + if (this.handleNodeSelected) { + // A higher-level component provided a callback. Run it. + this.handleNodeSelected(selectedNode); + } + } + @observe('selectedNode') + async _menuSelectedNodeExpand() { + if (this.renderHierarchy.renderedOpNames.includes(this.selectedNode) || this.selectedNode == '') { + return; + } else { + const current = this.selectedNode; + const params = new URLSearchParams(); + if (this.selection.tag) { + params.set('tag', this.selection.tag); + } + const hasBNode = this.renderHierarchy.renderedOpNames.some((name: string) => name.startsWith('B___')); + if (hasBNode) { + params.set('node', this.selectedNode); + } else { + params.set('node', `N___${this.selectedNode}`); + } + const nodeMap = this.renderHierarchy.hierarchy.getNodeMap(); + const expandnodesPath = getRouter().pluginRouteForSrc('graph_ascend', '/expandnodes', params); + let nodeName = ''; + try { + const compomentsStr = await tf_graph_parser.fetchPbTxt(expandnodesPath); + let compoments; + try { + compoments = JSON.parse(new TextDecoder().decode(compomentsStr).replace(/'/g, '"')) as object; + } catch (e) { + console.error('Parse tooltips failed, please check the format of tooltips in the input vis file'); + } + if (compoments[1].length === 0 && compoments[2].length === 0) { + return; + } + for (const i of compoments[1]) { + nodeName = compoments[0] + i; + const renderNode = this.renderHierarchy.getRenderNodeByName(nodeName); + if (nodeName in nodeMap && !renderNode.expanded) { + await this._nodeToggleExpand({ detail: { name: nodeName } }); + } + } + this.async(() => { + try { + this.set('selectedNode', ''); // 临时清空 + this.set('selectedNode', current); // 恢复原值 + } catch (e) { + console.error('Error during async set operation:', e); + } + }, 175); //代码会在延迟 175 毫秒后执行, 给浏览器足够的时间来处理多层展开带来的渲染和状态变化 + } catch (error) { + console.error('Error fetching expandnodesPath:', error); + } + } + } + + @observe('selectedEdge') + // Called when the selected edge changes, ie there is a new selected edge or + // the current one is unselected. + _selectedEdgeChanged() { + var selectedEdge = this.selectedEdge; + this._deselectPreviousEdge(); + // Visually mark this new edge as selected. + if (selectedEdge) { + this._lastSelectedEdgeGroup.classed(tf_graph_scene.Class.Edge.SELECTED, true); + // Update the color of the marker too if the edge has one. + this._updateMarkerOfSelectedEdge(selectedEdge, true); + } + if (this.handleEdgeSelected) { + // A higher-level component provided a callback. Run it. + this.handleEdgeSelected(selectedEdge); + } + } + @observe('highlightedEdge') + // Called when the highlighted edge changes. + _highlightedEdgeChange() { + let highlightedEdge = this.highlightedEdge; + this._lastHighlightedEdgeGroup.classed(tf_graph_scene.Class.Edge.HIGHLIGHTED, !!highlightedEdge); + if (highlightedEdge) { + this._updateMarkerOfSelectedEdge(highlightedEdge, false); + } else { + this._lastHighlightedEdgeGroup + .selectAll('path.' + tf_graph_scene.Class.Edge.LINE) + .each((d: tf_graph_render.EdgeData, i) => { + // Reset its marker. + if ( + d.label && + (d.v !== (this.selectedEdge as any)?.v || + d.w !== (this.selectedEdge as any)?.w || + d.id !== (this.selectedEdge as any)?.id) + ) { + const paths = this._lastHighlightedEdgeGroup.selectAll('path.edgeline'); + if (d.label.startMarkerId) { + paths.style('marker-start', `url(#${d.label.startMarkerId.replace('highlighted', 'dataflow')})`); + } + if (d.label.endMarkerId) { + paths.style('marker-end', `url(#${d.label.endMarkerId.replace('highlighted', 'dataflow')})`); + } + } + }); + } + } + // Called only when a new (non-null) node is selected. + _nodeSelected(event) { + if (this._allowGraphSelect) { + this.set('selectedNode', event.detail.name); + this.set('selectedEdge', null); + } + // Reset this variable as a bug in d3 zoom behavior can cause zoomend + // callback not to be called if a right-click happens during a zoom event. + this._allowGraphSelect = true; + } + _edgeSelected(event) { + if (this._allowGraphSelect) { + this.set('_lastSelectedEdgeGroup', event.detail.edgeGroup); + this.set('selectedEdge', event.detail.edgeData); + this.set('selectedNode', null); + } + // Reset this variable as a bug in d3 zoom behavior can cause zoomend + // callback not to be called if a right-click happens during a zoom event. + this._allowGraphSelect = true; + } + _nodeHighlighted(event) { + this.set('highlightedNode', event.detail.name); + } + _edgeHighlighted(event) { + if ( + event.detail.edgeData?.v === (this.selectedEdge as any)?.v && + event.detail.edgeData?.w === (this.selectedEdge as any)?.w && + event.detail.edgeData?.id === (this.selectedEdge as any)?.id + ) { + return; + } + this.set('_lastHighlightedEdgeGroup', event.detail.edgeGroup); + this.set('highlightedEdge', event.detail.edgeData); + } + _nodeUnhighlighted(event) { + this.set('highlightedNode', null); + } + _edgeUnhighlighted(event) { + this.set('highlightedEdge', null); + } + async _parentNodeToggleExpand(event) { + const nodeName = event.detail.nodeData.node.name; + const matched_node_link = event.detail.nodeData.node.matchedNodeLink; + if (matched_node_link) { + let matched = matched_node_link[matched_node_link.length - 1]; + this.set('selectedNode', matched); + } else { + const params = new URLSearchParams(); + params.set('run', this.selection.run); + params.set('node', nodeName); + if (this.selection.tag) { + params.set('tag', this.selection.tag); + } + params.set('batch', String(this.selection.batch === -1 ? -1 : this.selection.batch - 1)); + params.set('step', String(this.selection.step === -1 ? -1 : this.selection.step - 1)); + const graphPath = getRouter().pluginRouteForSrc('graph_ascend', '/parent', params); + const compomentsStr = await tf_graph_parser.fetchPbTxt(graphPath); + const compoments = new TextDecoder().decode(compomentsStr).replace(/'/g, '"'); + this.set('selectedNode', compoments); + } + } + _handleMenuExpandNodeChanged(event) { + const nodeName = event.detail.name; + const open = event.detail.open; + const renderNode = this.renderHierarchy.getRenderNodeByName(nodeName); + if (!renderNode.expanded && open == 'expand') { + this._nodeToggleExpand({ detail: { name: nodeName } }); + } else if (renderNode.expanded && open == 'unexpand') { + this._nodeToggleExpand({ detail: { name: nodeName } }); + } + this._nodeSelected({ detail: { name: nodeName } }); + } + async _nodeToggleExpand(event) { + // Immediately select the node that is about to be expanded. + // this._nodeSelected(event); + // Compute the sub-hierarchy scene. + const nodeName = event.detail.name; + const renderNode = this.renderHierarchy.getRenderNodeByName(nodeName); + // Op nodes are not expandable. + if (renderNode.node.type === tf_graph.NodeType.OP) { + return; + } + if (!renderNode.expanded && !this.renderHierarchy.checkSubhierarchy(nodeName)) { + const params = new URLSearchParams(); + params.set('run', this.selection.run); + params.set('node', renderNode.node.name || ''); + if (this.selection.tag) { + params.set('tag', this.selection.tag); + } + params.set('batch', String(this.selection.batch === -1 ? -1 : this.selection.batch - 1)); + params.set('step', String(this.selection.step === -1 ? -1 : this.selection.step - 1)); + const graphPath = getRouter().pluginRouteForSrc('graph_ascend', '/subgraph', params); + const arrayBuffer = await fetchPbTxt(graphPath); // 等待 fetchPbTxt 完成 + const graphDef = await parseGraphPbTxt(arrayBuffer); // 等待 parseGraphPbTxt 完成 + const slimGraph = await tf_graph.build(graphDef, tf_graph.DefaultBuildParams, undefined); // 等待 tf_graph.build 完成 + tf_hierarchy.update(this.renderHierarchy.hierarchy, slimGraph, nodeName); + this.renderHierarchy.buildSubhierarchy(nodeName, slimGraph); + renderNode.expanded = !renderNode.expanded; + this.async(() => { + (this.$.scene as any).setNodeExpanded(renderNode); + }, 75); + } else { + renderNode.expanded = !renderNode.expanded; + this.async(() => { + (this.$.scene as any).setNodeExpanded(renderNode); + }, 75); + } + } + + _nodeToggleExtract(event) { + // Toggle the include setting of the specified node appropriately. + var nodeName = event.detail.name; + this.nodeToggleExtract(nodeName); + } + nodeToggleExtract(nodeName: string) { + const renderNode = this.renderHierarchy.getRenderNodeByName(nodeName); + if (renderNode.node.include == tf_graph.InclusionType.INCLUDE) { + renderNode.node.include = tf_graph.InclusionType.EXCLUDE; + } else if (renderNode.node.include == tf_graph.InclusionType.EXCLUDE) { + renderNode.node.include = tf_graph.InclusionType.INCLUDE; + } else { + renderNode.node.include = this.renderHierarchy.isNodeAuxiliary(renderNode) + ? tf_graph.InclusionType.INCLUDE + : tf_graph.InclusionType.EXCLUDE; + } + // Rebuild the render hierarchy. + this._buildRenderHierarchy(this.graphHierarchy); + + tf_graph_util.notifyDebugEvent({ + actionId: tb_debug.GraphDebugEventId.NODE_AUXILIARY_EXTRACTION_CHANGED, + eventLabel: + renderNode.node.include === tf_graph.InclusionType.INCLUDE ? 'Auxiliary to Main' : 'Main to Auxiliary', + }); + } + _nodeToggleSeriesGroup(event) { + // Toggle the group setting of the specified node appropriately. + var nodeName = event.detail.name; + this.nodeToggleSeriesGroup(nodeName); + } + nodeToggleSeriesGroup(nodeName) { + this.set('progress', { + value: 0, + msg: '', + }); + var tracker = tf_graph_util.getTracker(this); + var hierarchyTracker = tf_graph_util.getSubtaskTracker(tracker, 100, 'Namespace hierarchy'); + + // Toggle the node's group type, setting to 'UNGROUP' if unspecified. + const newHierarchyParams = { + ...this.hierarchyParams, + seriesMap: this.graphHierarchy.buildSeriesGroupMapToggled(nodeName), + }; + + tf_graph_hierarchy.build(this.basicGraph, newHierarchyParams, hierarchyTracker).then( + function (graphHierarchy) { + this.set('graphHierarchy', graphHierarchy); + this._buildRenderHierarchy(this.graphHierarchy); + }.bind(this), + ); + } + _deselectPreviousEdge() { + const selectedSelector = '.' + tf_graph_scene.Class.Edge.SELECTED; + const selectedEdge = this.$.scene.shadowRoot?.querySelector(selectedSelector); + const selectedPaths = this.$.scene.shadowRoot?.querySelectorAll('path.edgeline'); + // Visually mark the previously selected edge (if any) as deselected. + !!selectedEdge && + d3 + .select(selectedEdge) + .classed(tf_graph_scene.Class.Edge.SELECTED, false) + .each((d: any, i) => { + // Reset its marker. + if (d.label && selectedPaths) { + const paths = d3.selectAll(selectedPaths); + if (d.label.startMarkerId) { + paths.style('marker-start', `url(#${d.label.startMarkerId})`); + } + if (d.label.endMarkerId) { + paths.style('marker-end', `url(#${d.label.endMarkerId})`); + } + } + }); + } + _updateMarkerOfSelectedEdge(selectedEdge, isSelected) { + if (selectedEdge.label) { + const statsName = isSelected ? 'selected-' : 'highlighted-'; + // The marker will vary based on the direction of the edge. + const markerId = selectedEdge.label.startMarkerId || selectedEdge.label.endMarkerId; + if (markerId) { + // Find the corresponding marker for a selected edge. + const selectedMarkerId = markerId.replace('dataflow-', statsName); + let selectedMarker = this.$.scene.shadowRoot?.querySelector('#' + selectedMarkerId) as HTMLElement; + if (!selectedMarker) { + // The marker for a selected edge of this size does not exist yet. Create it. + const originalMarker = this.$.scene.shadowRoot?.querySelector('#' + markerId); + selectedMarker = originalMarker?.cloneNode(true) as HTMLElement; + selectedMarker.setAttribute('id', selectedMarkerId); + selectedMarker.classList.add(`${statsName}arrowhead`); + originalMarker?.parentNode?.appendChild(selectedMarker); + } + // Make the path use this new marker while it is selected. + const markerAttribute = selectedEdge.label.startMarkerId ? 'marker-start' : 'marker-end'; + if (isSelected) { + this._lastSelectedEdgeGroup.selectAll('path.edgeline').style(markerAttribute, `url(#${selectedMarkerId})`); + } else { + this._lastHighlightedEdgeGroup.selectAll('path.edgeline').style(markerAttribute, `url(#${selectedMarkerId})`); + } + } + } + } + not(x) { + return !x; + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_board/tf-graph-board.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_board/tf-graph-board.ts new file mode 100644 index 0000000000000000000000000000000000000000..9d5ca4b7f9dd04edb5d989c67a1edd73d840e314 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_board/tf-graph-board.ts @@ -0,0 +1,382 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import { customElement, observe, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import '../polymer/irons_and_papers'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import '../tf_graph/tf-graph'; +import * as tf_graph from '../tf_graph_common/graph'; +import * as tf_graph_hierarchy from '../tf_graph_common/hierarchy'; +import * as tf_graph_render from '../tf_graph_common/render'; +import '../tf_graph_info/tf-graph-info'; + +/** + * Element for putting tf-graph and tf-graph-info side by side. + * + * Example + * + */ +@customElement('tf-graph-board') +class TfGraphBoard extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + +
+
+ +
+
+ + +
+ + + + +
+
+
+ `; + @property({ type: Object }) + graphHierarchy: tf_graph_hierarchy.Hierarchy; + @property({ type: Object }) + graph: tf_graph.SlimGraph; + // TODO(psybuzz): ideally, this would be a required property and the component + // that owns and the graph loader should create these params. + @property({ type: Object }) + hierarchyParams: tf_graph_hierarchy.HierarchyParams = tf_graph_hierarchy.DefaultHierarchyParams; + @property({ type: Object }) + stats: object; + /** + * A number between 0 and 100 denoting the % of progress + * for the progress bar and the displayed message. + * @type {{value: number, msg: string}} + */ + @property({ type: Object }) + progress: object; + @property({ type: Boolean }) + traceInputs: boolean; + @property({ type: Boolean }) + autoExtractNodes: boolean; + @property({ + type: Object, + notify: true, + }) + renderHierarchy: tf_graph_render.RenderGraphInfo; + // Whether debugger data is enabled for this instance of Tensorboard. + @property({ type: Boolean }) + debuggerDataEnabled: boolean; + @property({ + type: Array, + notify: true, + }) + // An array of alerts (in chronological order) provided by debugging libraries on when bad + // values (NaN, +/- Inf) appear. + debuggerNumericAlerts: unknown[]; + @property({ + type: Boolean, + notify: true, + }) + allStepsModeEnabled: boolean = false; + @property({ + type: Object, + }) + menu: any; + @property({ + type: Object, + }) + colorset: any; + @property({ + type: String, + notify: true, + }) + selectedNode: string; + @property({ + type: Object, + notify: true, + }) + selectedEdge: tf_graph_render.EdgeData; + @property({ + type: String, + }) + compatNodeTitle: string = 'TPU Compatibility'; + // A function with signature EdgeThicknessFunction that computes the + // thickness of a given edge. + @property({ type: Object }) + edgeWidthFunction: object; + // The enum value of the include property of the selected node. + @property({ type: Number }) + _selectedNodeInclude: number; + @property({ type: String }) + _highlightedNode: string; + // An optional function that takes a node selected event (whose `detail` + // property is the selected node ... which could be null if a node is + // deselected). Called whenever a node is selected or deselected. + @property({ type: Object }) + handleNodeSelected: object; + // An optional function that computes the label for an edge. Should + // implement the EdgeLabelFunction signature. + @property({ type: Object }) + edgeLabelFunction: object; + // An optional callback that implements the tf.graph.edge.EdgeSelectionCallback signature. If + // provided, edges are selectable, and this callback is run when an edge is selected. + @property({ type: Object }) + handleEdgeSelected: object; + @property({ type: Object }) + selection: object; + @property({ type: Object }) + tooltips: object; + + ready() { + super.ready(); + + const handle = this.shadowRoot?.getElementById('resize-handle') as HTMLElement; + + if (handle) { + let isResizing = false; + let lastDownX = 0; + let lastDownY = 0; + + handle.addEventListener('mousedown', (e: MouseEvent) => { + e.preventDefault(); + isResizing = true; + lastDownX = e.clientX; + lastDownY = e.clientY; + document.addEventListener('mousemove', handleMouseMove); + document.addEventListener('mouseup', handleMouseUp); + }); + + const handleMouseMove = (e: MouseEvent) => { + if (isResizing) { + const graphInfoDiv = this.shadowRoot?.getElementById('graph-info') as HTMLElement; + if (graphInfoDiv) { + const rect = graphInfoDiv.getBoundingClientRect(); + const width = rect.width + (lastDownX - e.clientX); + const height = rect.height + (e.clientY - lastDownY); + graphInfoDiv.style.width = `${Math.max(width, 100)}px`; + graphInfoDiv.style.height = `${Math.max(height, 65)}px`; + lastDownX = e.clientX; + lastDownY = e.clientY; + } + } + }; + + const handleMouseUp = () => { + isResizing = false; + document.removeEventListener('mousemove', handleMouseMove); + document.removeEventListener('mouseup', handleMouseUp); + }; + } + } + fit() { + (this.$.graph as any).fit(); + } + async downloadAsImage(filename: string) { + const blob = await (this.$.graph as any).getImageBlob(); + const element = document.createElement('a'); + (element as any).href = (URL as any).createObjectURL(blob); + element.download = filename; + element.click(); + URL.revokeObjectURL(element.href); + } + /** True if the progress is not complete yet (< 100 %). */ + _isNotComplete(progress) { + return progress.value < 100; + } + _getContainerClass(progress) { + var result = 'container'; + if (progress.error) { + result += ' error'; + } + if (this._isNotComplete(progress)) { + result += ' loading'; + } + return result; + } + _onNodeInclusionToggled(event) { + (this.$.graph as any).nodeToggleExtract(event.detail.name); + } + _onNodeSeriesGroupToggled(event) { + (this.$.graph as any).nodeToggleSeriesGroup(event.detail.name); + } + @observe('selectedNode', 'renderHierarchy') + _updateNodeInclude() { + const node = !this.renderHierarchy ? null : this.renderHierarchy.getNodeByName(this.selectedNode); + this._selectedNodeInclude = node ? node.include : tf_graph.InclusionType.UNSPECIFIED; + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/colors.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/colors.ts new file mode 100644 index 0000000000000000000000000000000000000000..28d6d5fac2cfd214eea0c5905eace02463b87207 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/colors.ts @@ -0,0 +1,128 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +export let COLORS = [ + { + name: 'Google Blue', + color: '#4184f3', + active: '#3a53c5', + disabled: '#cad8fc', + }, + { + name: 'Google Red', + color: '#db4437', + active: '#8f2a0c', + disabled: '#e8c6c1', + }, + { + name: 'Google Yellow', + color: '#f4b400', + active: '#db9200', + disabled: '#f7e8b0', + }, + { + name: 'Google Green', + color: '#0f9d58', + active: '#488046', + disabled: '#c2e1cc', + }, + { + name: 'Purple', + color: '#aa46bb', + active: '#5c1398', + disabled: '#d7bce6', + }, + { + name: 'Teal', + color: '#00abc0', + active: '#47828e', + disabled: '#c2eaf2', + }, + { + name: 'Deep Orange', + color: '#ff6f42', + active: '#ca4a06', + disabled: '#f2cbba', + }, + { + name: 'Lime', + color: '#9d9c23', + active: '#7f771d', + disabled: '#f1f4c2', + }, + { + name: 'Indigo', + color: '#5b6abf', + active: '#3e47a9', + disabled: '#c5c8e8', + }, + { + name: 'Pink', + color: '#ef6191', + active: '#ca1c60', + disabled: '#e9b9ce', + }, + { + name: 'Deep Teal', + color: '#00786a', + active: '#2b4f43', + disabled: '#bededa', + }, + { + name: 'Deep Pink', + color: '#c1175a', + active: '#75084f', + disabled: '#de8cae', + }, + { + name: 'Gray', + color: '#9E9E9E', //500 + active: '#424242', //800 + disabled: 'F5F5F5', //100 + }, +].reduce((m, c) => { + m[c.name] = c; + return m; +}, {}); +/** + * Mapping from op category to color palette name + * e.g., OP_GROUP_COLORS['state_ops'] = 'Google Blue'; + */ +export let OP_GROUP_COLORS = [ + { + color: 'Google Red', + groups: [ + 'gen_legacy_ops', + 'legacy_ops', + 'legacy_flogs_input', + 'legacy_image_input', + 'legacy_input_example_input', + 'legacy_sequence_input', + 'legacy_seti_input_input', + ], + }, + {color: 'Deep Orange', groups: ['constant_ops']}, + {color: 'Indigo', groups: ['state_ops']}, + {color: 'Purple', groups: ['nn_ops', 'nn']}, + {color: 'Google Green', groups: ['math_ops']}, + {color: 'Lime', groups: ['array_ops']}, + {color: 'Teal', groups: ['control_flow_ops', 'data_flow_ops']}, + {color: 'Pink', groups: ['summary_ops']}, + {color: 'Deep Pink', groups: ['io_ops']}, +].reduce((m, c) => { + c.groups.forEach(function (group) { + m[group] = c.color; + }); + return m; +}, {}); diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/common.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/common.ts new file mode 100644 index 0000000000000000000000000000000000000000..b33b3e99fa69ef5d5188b430bb5cdbcd94fad9b6 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/common.ts @@ -0,0 +1,285 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/** + * @fileoverview Common interfaces for the tensorflow graph visualizer. + */ +import * as d3 from 'd3'; + +export interface ProgressTracker { + updateProgress(incrementValue: number): void; + setMessage(msg: string): void; + reportError(msg: string, err: Error): void; +} +// Note that tf-graph-control depends on the value of the enum. +// Polymer does not let one use JS variable as a prop. +export enum SelectionType { + OP_GRAPH = 'op_graph', + CONCEPTUAL_GRAPH = 'conceptual_graph', + PROFILE = 'profile', +} + +/** Enums element class of objects in the scene */ +export let Class = { + Node: { + // element that contains nodes. + CONTAINER: 'nodes', + // element that contains detail about a node. + GROUP: 'node', + // element that contains visual elements (like rect, ellipse). + SHAPE: 'nodeshape', + OUTER: 'outer', + // <*> element(s) under SHAPE that should receive color updates. + COLOR_TARGET: 'nodecolortarget', + // element showing the node's label. + LABEL: 'nodelabel', + // element that contains all visuals for the expand/collapse + // button for expandable group nodes. + BUTTON_CONTAINER: 'buttoncontainer', + // element that surrounds expand/collapse buttons. + BUTTON_CIRCLE: 'buttoncircle', + // element of the expand button. + EXPAND_BUTTON: 'expandbutton', + // element of the collapse button. + COLLAPSE_BUTTON: 'collapsebutton', + }, + Edge: { + CONTAINER: 'edges', + GROUP: 'edge', + LINE: 'edgeline', + REFERENCE_EDGE: 'referenceedge', + REF_LINE: 'refline', + SELECTABLE: 'selectableedge', + SELECTED: 'selectededge', + STRUCTURAL: 'structural', + HIGHLIGHTED: 'highlighted', + }, + Annotation: { + OUTBOX: 'out-annotations', + INBOX: 'in-annotations', + GROUP: 'annotation', + NODE: 'annotation-node', + EDGE: 'annotation-edge', + CONTROL_EDGE: 'annotation-control-edge', + LABEL: 'annotation-label', + ELLIPSIS: 'annotation-ellipsis', + }, + Scene: { + GROUP: 'scene', + CORE: 'core', + FUNCTION_LIBRARY: 'function-library', + INEXTRACT: 'in-extract', + OUTEXTRACT: 'out-extract', + }, + Subscene: { GROUP: 'subscene' }, + OPNODE: 'op', + METANODE: 'meta', + SERIESNODE: 'series', + BRIDGENODE: 'bridge', + ELLIPSISNODE: 'ellipsis', + API_LIST: 'api_list', +}; + +// Please keep this in sync with tf-graph-scene.html.ts. +export const FontSizeInPx: Record> = { + Edge: { + LABEL: 3.5, + }, + Annotation: { + LABEL: 5, + }, + Node: { + EXPANDED_LABEL: 9, + SERIES_LABEL: 8, + OP_LABEL: 6, + HEALTH_PILL_STAT_LABEL: 4, + }, +}; + +export const SVG_NAMESPACE = 'http://www.w3.org/2000/svg'; + +/** + * Given a container d3 selection, select a child element of a given tag and + * class. If multiple children matches the tag and class name, returns only + * the first one. + * + * @param container + * @param tagName tag name. + * @param className (optional) Class name or list of class names. + * @return selection of the element, or an empty selection + */ +export function selectChild( + container, + tagName: string, + className?: string | string[], +): d3.Selection { + let children = container.node().childNodes; + for (let i = 0; i < children.length; i++) { + let child = children[i]; + if (child.tagName === tagName) { + if (className instanceof Array) { + let hasAllClasses = true; + for (let j = 0; j < className.length; j++) { + hasAllClasses = hasAllClasses && child.classList.contains(className[j]); + } + if (hasAllClasses) { + return d3.select(child); + } + } else if (!className || child.classList.contains(className)) { + return d3.select(child); + } + } + } + return d3.select(null); +} + +/** + * Given a container d3 selection, select a child svg element of a given tag + * and class if exists or append / insert one otherwise. If multiple children + * matches the tag and class name, returns only the first one. + * + * @param container + * @param tagName tag name. + * @param className (optional) Class name or a list of class names. + * @param before (optional) reference DOM node for insertion. + * @return selection of the element + */ +export function selectOrCreateChild( + container, + tagName: string, + className?: string | string[], + before?, +): d3.Selection { + let child = selectChild(container, tagName, className); + if (!child.empty()) { + return child; + } + let newElement = document.createElementNS('http://www.w3.org/2000/svg', tagName); + if (className instanceof Array) { + for (let i = 0; i < className.length; i++) { + newElement.classList.add(className[i]); + } + } else { + newElement.classList.add(className!); + } + if (before) { + // if before exists, insert + container.node().insertBefore(newElement, before); + } else { + // otherwise, append + container.node().appendChild(newElement); + } + return ( + d3 + .select(newElement) + // need to bind data to emulate d3_selection.append + .datum(container.datum()) + ); +} + +/** + * Execution stats for the node. + */ +export class NodeStats { + constructor(outputSize: number[][]) { + this.outputSize = outputSize; + } + /** + * Add the start and end time for a particular kernel execution of this op. + * Ops can have multiple kernel executions within the same session run. + */ + addExecutionTime(startTime: number, endTime: number) { + if (this.startTime != null) { + this.startTime = Math.min(this.startTime, startTime); + } else { + this.startTime = startTime; + } + if (this.endTime != null) { + this.endTime = Math.max(this.endTime, endTime); + } else { + this.endTime = endTime; + } + } + /** + * Add the bytes allocated for a particular kernel execution of this op. + * Ops can have multiple kernel executions within the same session run. + */ + addBytesAllocation(totalBytes: number) { + if (this.totalBytes != null) { + this.totalBytes = Math.max(this.totalBytes, totalBytes); + } else { + this.totalBytes = totalBytes; + } + } + /** + * Absolute start time for the very first kernel execution of this op. + */ + startTime: number; + /** + * Absolute end time for the very last kernel execution of this op. + */ + endTime: number; + /** + * Total number of bytes used for the node. Sum of all children + * if it is a Group node. + */ + totalBytes = 0; + /** + * The shape of each output tensors, if there are any. + * Empty if it is a Group node. + */ + outputSize: number[][]; + /** + * Combines the specified stats with the current stats. + * Modifies the current object. This method is used to + * compute aggregate stats for group nodes. + */ + combine(stats: NodeStats): void { + if (stats.totalBytes != null) { + this.totalBytes += stats.totalBytes; + } + if (stats.getTotalMicros() != null) { + this.addExecutionTime(stats.startTime, stats.endTime); + } + } + /** + * Total number of compute time in microseconds used for the node. + * Sum of all children if it is a Group node. Null if it is unknown. + * This method can not be scaffolded under a getter attribute because + * ECMAScript 5 does not support getter attributes. + */ + getTotalMicros(): number { + if (this.startTime == null || this.endTime == null) { + return null!; + } + return this.endTime - this.startTime; + } +} + +/** The minimum stroke width of an edge. */ +export const MIN_EDGE_WIDTH = 0.75; +/** The maximum stroke width of an edge. */ +export const MAX_EDGE_WIDTH = 12; +/** The exponent used in the power scale for edge thickness. */ +const EDGE_WIDTH_SCALE_EXPONENT = 0.3; +/** The domain (min and max value) for the edge width. */ +const DOMAIN_EDGE_WIDTH_SCALE = [1, 5000000]; +export const EDGE_WIDTH_SIZE_BASED_SCALE: d3.ScalePower = d3 + .scalePow() + .exponent(EDGE_WIDTH_SCALE_EXPONENT) + .domain(DOMAIN_EDGE_WIDTH_SCALE) + .range([MIN_EDGE_WIDTH, MAX_EDGE_WIDTH]) + .clamp(true); + +export let globalTooltips: { [key: string]: string } = {}; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/contextmenu.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/contextmenu.ts new file mode 100644 index 0000000000000000000000000000000000000000..c1cd9e2c57a2dbf0a472d78792302038bc86abf2 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/contextmenu.ts @@ -0,0 +1,101 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as d3 from 'd3'; +import { TfGraphScene } from './tf-graph-scene'; + +export interface TitleFunction { + (data: any): string; +} +/** Function that takes action based on item clicked in the context menu. */ +export interface ActionFunction { + (elem: any, d: any, i: number): void; +} +/** + * The interface for an item in the context menu + */ +export interface ContextMenuItem { + title: TitleFunction; + action: ActionFunction; +} +/** + * Returns the top and left distance of the scene element from the top left + * corner of the screen. + */ +function getOffset(sceneElement) { + let leftDistance = 0; + let topDistance = 0; + let currentElement = sceneElement; + while (currentElement && currentElement.offsetLeft >= 0 && currentElement.offsetTop >= 0) { + leftDistance += currentElement.offsetLeft - currentElement.scrollLeft; + topDistance += currentElement.offsetTop - currentElement.scrollTop; + currentElement = currentElement.offsetParent; + } + return { + left: leftDistance, + top: topDistance, + }; +} +/** + * Returns the event listener, which can be used as an argument for the d3 + * selection.on function. Renders the context menu that is to be displayed + * in response to the event. + */ +export function getMenu(sceneElement: TfGraphScene, menu: ContextMenuItem[], nodeData) { + const menuNode = sceneElement.getContextMenu(); + const menuSelection = d3.select(sceneElement.getContextMenu()); + // Function called to populate the context menu. + return function (data, index: number): void { + // Position and display the menu. + let event = d3.event; + const sceneOffset = getOffset(sceneElement); + menuSelection + .style('display', 'block') + .style('left', event.clientX - sceneOffset.left + 1 + 'px') + .style('top', event.clientY - sceneOffset.top + 1 + 'px'); + // Stop the event from propagating further. + event.preventDefault(); + event.stopPropagation(); + function maybeCloseMenu(event?: any) { + if (event && event.composedPath().includes(menuNode)) { + return; + } + menuSelection.style('display', 'none'); + document.body.removeEventListener('mousedown', maybeCloseMenu, { + capture: true, + }); + } + // Dismiss and remove the click listener as soon as there is a mousedown + // on the document. We use capture listener so no component can stop + // context menu from dismissing due to stopped propagation. + document.body.addEventListener('mousedown', maybeCloseMenu, { + capture: true, + }); + // Add provided items to the context menu. + menuSelection.text(''); + let list = menuSelection.append('ul'); + list + .selectAll('li') + .data(menu) + .enter() + .append('li') + .on('click', (d, i) => { + sceneElement.fire('parent-node-toggle-expand', { nodeData }); + maybeCloseMenu(); + }) + .text(function (d) { + return '展开对应侧节点'; + }); + }; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/edge.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/edge.ts new file mode 100644 index 0000000000000000000000000000000000000000..58c633a0db1495d06ddabfea903fb5cde831d273 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/edge.ts @@ -0,0 +1,436 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as d3 from 'd3'; +import { graphlib } from 'dagre'; +import * as _ from 'lodash'; +import * as tf_graph_common from './common'; +import { Class, MAX_EDGE_WIDTH, MIN_EDGE_WIDTH } from './common'; +import { BaseEdge, EDGE_KEY_DELIM, Metaedge, OpNode, TensorShape } from './graph'; +import * as render from './render'; +import { EdgeData } from './render'; +import { TfGraphScene } from './tf-graph-scene'; + +/** Delimiter between dimensions when showing sizes of tensors. */ +const TENSOR_SHAPE_DELIM = '\u00D7'; + +let arrowheadMap = d3 + .scaleQuantize() + .domain([MIN_EDGE_WIDTH, MAX_EDGE_WIDTH]) + .range(['small', 'medium', 'large', 'xlarge']); +/** Minimum stroke width to put edge labels in the middle of edges */ +const CENTER_EDGE_LABEL_MIN_STROKE_WIDTH = 2.5; + +/** + * Function run when an edge is selected. + */ +export interface EdgeSelectionCallback { + (edgeData: EdgeData): void; +} +export function getEdgeKey(edgeObj: EdgeData) { + return edgeObj.v + EDGE_KEY_DELIM + edgeObj.w + EDGE_KEY_DELIM + edgeObj.id; +} +/** + * Select or Create a 'g.edges' group to a given sceneGroup + * and builds a number of 'g.edge' groups inside the group. + * + * Structure Pattern: + * + * + * + * + * + * ... + * + * + * + * @param sceneGroup container + * @param graph + * @param sceneElement polymer element. + * @return selection of the created nodeGroups + */ +export function buildGroup(sceneGroup, graph: graphlib.Graph, sceneElement: TfGraphScene) { + const sceneComponent = sceneElement as any; + let edges: EdgeData[] = []; + edges = _.reduce( + graph.edges(), + (edges, edgeObj) => { + let edgeLabel = graph.edge(edgeObj); + edges.push({ + v: edgeObj.v, + w: edgeObj.w, + id: edgeObj.name, + label: edgeLabel, + }); + return edges; + }, + edges, + ); + let container = tf_graph_common.selectOrCreateChild(sceneGroup, 'g', Class.Edge.CONTAINER); + // Select all children and join with data. + // (Note that all children of g.edges are g.edge) + let edgeGroups = (container as any) + .selectAll(function () { + return this.childNodes; + }) + .data(edges, getEdgeKey); + // Make edges a group to support rendering multiple lines for metaedge + edgeGroups + .enter() + .append('g') + .attr('class', Class.Edge.GROUP) + .attr('data-edge', getEdgeKey) + .each(function (d: EdgeData) { + let edgeGroup = d3.select(this); + d.label.edgeGroup = edgeGroup; + // index node group for quick highlighting + sceneComponent._edgeGroupIndex[getEdgeKey(d)] = edgeGroup; + + edgeGroup + .on('click', (d) => { + // Stop this event's propagation so that it isn't also considered + // a graph-select. + (d3.event).stopPropagation(); + sceneComponent.fire('edge-select', { + edgeData: d, + edgeGroup: edgeGroup, + }); + }) + .on('mouseover', (d) => { + sceneElement.fire('edge-highlight', { + edgeData: d, + edgeGroup: edgeGroup, + }); + }) + .on('mouseout', (d) => { + sceneElement.fire('edge-unhighlight', { + edgeData: d, + edgeGroup: edgeGroup, + }); + }); + // Add line during enter because we're assuming that type of line + // normally does not change. + appendEdge(edgeGroup, d, sceneComponent); + }) + .merge(edgeGroups) + .each(function () { + position(sceneElement, this); + }) + .each(function (d) { + stylize(d3.select(this), d, sceneComponent); + }); + edgeGroups + .exit() + .each((d) => { + delete sceneComponent._edgeGroupIndex[getEdgeKey(d)]; + }) + .remove(); + return edgeGroups; +} +/** + * Returns the label for the given base edge. + * The label is the shape of the underlying tensor. + */ +export function getLabelForBaseEdge(baseEdge: BaseEdge, renderInfo: render.RenderGraphInfo): string { + const outTensorKey = baseEdge.outputTensorKey; + let shape: TensorShape = []; + if (outTensorKey && outTensorKey.startsWith('[')) { + shape = JSON.parse(outTensorKey) as TensorShape; + } else { + let node = renderInfo.getNodeByName(baseEdge.v!); + if (node.outputShapes == null || _.isEmpty(node.outputShapes)) { + return null!; + } + shape = node.outputShapes[baseEdge.outputTensorKey]; + } + if (shape == null) { + return null!; + } + if (shape.length === 0) { + return 'scalar'; + } + return shape + .map((size) => { + return size === -1 ? '?' : size; + }) + .join(TENSOR_SHAPE_DELIM); +} +/** + * Creates the label for the given metaedge. If the metaedge consists + * of only 1 tensor, and it's shape is known, the label will contain that + * shape. Otherwise, the label will say the number of tensors in the metaedge. + */ +export function getLabelForEdge(metaedge: Metaedge, renderInfo: render.RenderGraphInfo): string { + if (renderInfo.edgeLabelFunction) { + // The user has specified a means of computing the label. + return renderInfo.edgeLabelFunction(metaedge, renderInfo); + } + // Compute the label based on either tensor count or size. + let isMultiEdge = metaedge.baseEdgeList.length > 1; + return isMultiEdge + ? metaedge.baseEdgeList.length + ' tensors' + : getLabelForBaseEdge(metaedge.baseEdgeList[0], renderInfo); +} +/** + * Computes the index into a set of points that constitute a path for which the + * distance along the path from the initial point is as large as possible + * without exceeding the length. This function was introduced after the + * native getPathSegAtLength method got deprecated by SVG 2. + * @param points Array of path control points. A point has x and y properties. + * Must be of length at least 2. + * @param length The length (float). + * @param lineFunc A function that takes points and returns the "d" attribute + * of a path made from connecting the points. + * @return The index into the points array. + */ +function getPathSegmentIndexAtLength( + points: render.Point[], + length: number, + lineFunc: (points: render.Point[]) => string, +): number { + const path = document.createElementNS(tf_graph_common.SVG_NAMESPACE, 'path'); + for (let i = 1; i < points.length; i++) { + path.setAttribute('d', lineFunc(points.slice(0, i))); + if (path.getTotalLength() > length) { + // This many points has already exceeded the length. + return i - 1; + } + } + // The entire path is shorter than the specified length. + return points.length - 1; +} +/** + * Shortens the path enought such that the tip of the start/end marker will + * point to the start/end of the path. The marker can be of arbitrary size. + * + * @param points Array of path control points. + * @param marker D3 selection of the svg element. + * @param isStart Is the marker a `start-marker`. If false, the marker is + * an `end-marker`. + * @return The new array of control points. + */ +function adjustPathPointsForMarker( + points: render.Point[], + marker: d3.Selection, + isStart: boolean, +): render.Point[] { + let lineFunc = d3 + .line() + .x((d) => d.x) + .y((d) => d.y); + let path = d3.select(document.createElementNS('http://www.w3.org/2000/svg', 'path')).attr('d', lineFunc(points)!); + let markerWidth = +marker.attr('markerWidth'); + let viewBox = marker.attr('viewBox').split(' ').map(Number); + let viewBoxWidth = viewBox[2] - viewBox[0]; + let refX = +marker.attr('refX'); + let pathNode = path.node(); + if (isStart) { + // The edge flows downwards. Do not make the edge go the whole way, lest we + // clobber the arrowhead. + const fractionStickingOut = 1 - refX / viewBoxWidth; + const length = markerWidth * fractionStickingOut; + const point = pathNode.getPointAtLength(length); + // Figure out how many segments of the path we need to remove in order + // to shorten the path. + // @ts-ignore TS2345: Argument of type 'Line' is not assignable to parameter of type '(points: Point[]) => string'. + const segIndex = getPathSegmentIndexAtLength(points, length, lineFunc); + // Update the very first segment. + points[segIndex - 1] = { x: point.x, y: point.y }; + // Ignore every point before segIndex - 1. + return points.slice(segIndex - 1); + } else { + // The edge flows upwards. Do not make the edge go the whole way, lest we + // clobber the arrowhead. + const fractionStickingOut = 1 - refX / viewBoxWidth; + const length = pathNode.getTotalLength() - markerWidth * fractionStickingOut; + const point = pathNode.getPointAtLength(length); + // Figure out how many segments of the path we need to remove in order + // to shorten the path. + // @ts-ignore TS2345: Argument of type 'Line' is not assignable to parameter of type '(points: Point[]) => string'. + const segIndex = getPathSegmentIndexAtLength(points, length, lineFunc); + // Update the very last segment. + points[segIndex] = { x: point.x, y: point.y }; + // Ignore every point after segIndex. + return points.slice(0, segIndex + 1); + } +} +/** + * For a given d3 selection and data object, create a path to represent the + * edge described in d.label. + * + * If d.label is defined, it will be a RenderMetaedgeInfo instance. It + * will sometimes be undefined, for example for some Annotation edges for which + * there is no underlying Metaedge in the hierarchical graph. + */ +export function appendEdge( + edgeGroup, + d: EdgeData, + sceneElement: { + renderHierarchy: render.RenderGraphInfo; + handleEdgeSelected: Function; + }, + edgeClass?: string, +) { + edgeClass = edgeClass || Class.Edge.LINE; // set default type + if (d.label && d.label.structural) { + edgeClass += ' ' + Class.Edge.STRUCTURAL; + } + if (d.label && d.label.metaedge && d.label.metaedge.numRefEdges) { + edgeClass += ' ' + Class.Edge.REFERENCE_EDGE; + } + if (sceneElement.handleEdgeSelected) { + // The user has opted to make edges selectable. + edgeClass += ' ' + Class.Edge.SELECTABLE; + } + // Give the path a unique id, which will be used to link + // the textPath (edge label) to this path. + let pathId = 'path_' + getEdgeKey(d); + let strokeWidth; + if (sceneElement.renderHierarchy.edgeWidthFunction) { + // Compute edge thickness based on the user-specified method. + strokeWidth = sceneElement.renderHierarchy.edgeWidthFunction(d, edgeClass); + } else { + // Encode tensor size within edge thickness. + let size = 1; + if (d.label != null && d.label.metaedge != null) { + // There is an underlying Metaedge. + size = d.label.metaedge.totalSize; + } + strokeWidth = sceneElement.renderHierarchy.edgeWidthSizedBasedScale(size); + } + let path = edgeGroup + .append('path') + .attr('id', pathId) + .attr('class', edgeClass) + .style('stroke-width', strokeWidth + 'px'); + // Check if there is a reference edge and add an arrowhead of the right size. + if (d.label && d.label.metaedge) { + if (d.label.metaedge.numRefEdges) { + // We have a reference edge. + const markerId = `reference-arrowhead-${arrowheadMap(strokeWidth)}`; + path.style('marker-start', `url(#${markerId})`); + d.label.startMarkerId = markerId; + } else { + // We have a dataflow edge. + const markerId = `dataflow-arrowhead-${arrowheadMap(strokeWidth)}`; + path.style('marker-end', `url(#${markerId})`); + d.label.endMarkerId = markerId; + } + } + if (d.label == null || d.label.metaedge == null) { + // There is no associated metaedge, thus no text. + // This happens for annotation edges. + return; + } + let labelForEdge = getLabelForEdge(d.label.metaedge, sceneElement.renderHierarchy); + if (labelForEdge == null) { + // We have no information to show on this edge. + return; + } + // Put edge label in the middle of edge only if the edge is thick enough. + let baseline = strokeWidth > CENTER_EDGE_LABEL_MIN_STROKE_WIDTH ? 'central' : 'text-after-edge'; + edgeGroup + .append('text') + .append('textPath') + .attr('xlink:href', '#' + pathId) + .attr('startOffset', '50%') + .attr('text-anchor', 'middle') + .attr('dominant-baseline', 'central') + .text(labelForEdge); +} +export let interpolate: d3.Line<{ + x: number; + y: number; +}> = d3 + .line<{ + x: number; + y: number; + }>() + .curve(d3.curveBasis) + .x((d) => { + return d.x; + }) + .y((d) => { + return d.y; + }); +/** + * Returns a tween interpolator for the endpoint of an edge path. + */ +function getEdgePathInterpolator( + component: HTMLElement, + renderPath: SVGPathElement, + d: EdgeData, + i: number, + a: SVGPathElement[], +) { + let renderMetaedgeInfo = d.label; + let adjoiningMetaedge = renderMetaedgeInfo.adjoiningMetaedge; + let points = renderMetaedgeInfo.points; + // Adjust the path so that start/end markers point to the end + // of the path. + const { shadowRoot } = component; + if (d.label.startMarkerId) { + points = adjustPathPointsForMarker( + points, + d3.select(shadowRoot?.querySelector('#' + d.label.startMarkerId)!), + true, + ); + } + if (d.label.endMarkerId) { + points = adjustPathPointsForMarker(points, d3.select(shadowRoot?.querySelector('#' + d.label.endMarkerId)!), false); + } + if (!adjoiningMetaedge) { + return d3.interpolate(a, interpolate(points)!); + } + // Get the adjoining path that matches the adjoining metaedge. + let adjoiningPath = (adjoiningMetaedge.edgeGroup.node()).firstChild; + // Find the desired SVGPoint along the adjoining path, then convert those + // coordinates into the space of the renderPath using its Current + // Transformation Matrix (CTM). + let inbound = renderMetaedgeInfo.metaedge.inbound; + return function (t) { + let adjoiningPoint = adjoiningPath + .getPointAtLength(inbound ? adjoiningPath.getTotalLength() : 0) + .matrixTransform(adjoiningPath.getCTM()!) + .matrixTransform(renderPath.getCTM()?.inverse()!); + // Update the relevant point in the renderMetaedgeInfo's points list, then + // re-interpolate the path. + let index = inbound ? 0 : points.length - 1; + points[index].x = adjoiningPoint.x; + points[index].y = adjoiningPoint.y; + let dPath = interpolate(points); + return dPath; + }; +} +function position(component: HTMLElement, edgeGroup: HTMLElement) { + d3.select(edgeGroup) + .select('path.' + Class.Edge.LINE) + .transition() + // @ts-ignore TS2769: No overload matches this call. complicated return type mismatch issue + .attrTween('d', function (d: EdgeData, i: number, a: SVGPathElement[]) { + return getEdgePathInterpolator(component, this as SVGPathElement, d, i, a); + }); +} +/** + * For a given d3 selection and data object, mark the edge as a control + * dependency if it contains only control edges. + * + * d's label property will be a RenderMetaedgeInfo object. + */ +export function stylize(edgeGroup, d: EdgeData, sceneElement: TfGraphScene) { + // const isHighlighted = sceneElement.isEdgeHighlighted(d); + edgeGroup.classed('faded', d.label.isFadedOut); + let metaedge = d.label.metaedge; + edgeGroup.select('path.' + Class.Edge.LINE).classed('control-dep', metaedge && !metaedge.numRegularEdges); +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/graph.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/graph.ts new file mode 100644 index 0000000000000000000000000000000000000000..fc34e13121acb816ff246fc125d7ada603322ca4 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/graph.ts @@ -0,0 +1,1463 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import { graphlib } from 'dagre'; +import * as _ from 'lodash'; +import * as tb_debug from '../tb_debug'; +import { NodeStats, ProgressTracker } from './common'; +import { Hierarchy } from './hierarchy'; +import * as tf_graph_proto from './proto'; +import * as tf_graph_util from './util'; + +export const NAMESPACE_DELIM = '/'; +export const ROOT_NAME = '__root__'; +export const FUNCTION_LIBRARY_NODE_PREFIX = '__function_library__'; +/** Attribute key used for storing attributes that are too large. */ +export const LARGE_ATTRS_KEY = '_too_large_attrs'; +/** Precision attributes are used to represent the color of nodes. */ +export const NODE_TYPE = 'node_type'; +export const PRECISION_INDEX = 'precision_index'; +export const MATCHED_NODE_LINK = 'matched_node_link'; +/** + * Maximum allowed size in bytes, before the attribute is considered large + * and filtered out of the graph. + */ +export const LIMIT_ATTR_SIZE = 1024; +// Separator between the source and the destination name of the edge. +export const EDGE_KEY_DELIM = '--'; +export enum GraphType { + FULL, + EMBEDDED, + META, + SERIES, + CORE, + SHADOW, + BRIDGE, + EDGE, +} +export enum NodeType { + META = 0, + OP = 1, + SERIES, + BRIDGE, + ELLIPSIS, + API_LIST = 9, +} +/** Indicates if a node is to be included in the main graph when rendered. */ +export enum InclusionType { + INCLUDE, + EXCLUDE, + UNSPECIFIED, +} +/** Indicates if a series is to be grouped in the graph when rendered. */ +export enum SeriesGroupingType { + GROUP, + UNGROUP, +} +/** Attribute key reserved for the shapes of the output tensors. */ +const OUTPUT_SHAPES_KEY = '_output_shapes'; +/** Attribute key reserved for the XLA cluster that an op runs on. */ +const _XLA_CLUSTER_KEY = '_XlaCluster'; +/** + * A BaseEdge is the label object (in the graphlib sense) for an edge in the + * original, full graph produced after parsing. Subsequent graphs, like those + * which belong to Metanodes, should not use BaseEdge objects, but instead + * contain Metaedges (which in turn may contain any number of BaseEdges). + */ +export interface BaseEdge { + isReferenceEdge: boolean; + /** The index of the output tensor of the source node. */ + outputTensorKey: string; + attr?: { + [key: string]: any; + }; + v?: string; + w?: string; +} +/** + * A SlimGraph is inspired by graphlib.Graph, but having only the functionality + * that we need. + */ +export class SlimGraph { + nodes: { + [nodeName: string]: OpNode; + }; + metaNodes: { + [nodeName: string]: Metanode; + }; + edges: BaseEdge[]; + constructor() { + this.nodes = {}; + this.metaNodes = {}; + this.edges = []; + } +} +export interface NormalizedInput { + name: string; + /** The index of the output tensor of the source node. */ + outputTensorKey: string; +} +export interface BuildParams { + enableEmbedding: boolean; + inEmbeddingTypes: string[]; + outEmbeddingTypes: string[]; + refEdges: { + [inputEdge: string]: boolean; + }; +} +/** + * The most basic information about a node in the hierarchical graph. + */ +export interface Node { + /** The name of the node, used frequently to look up nodes by name. */ + name: string; + /** Which type of node this is. */ + type: NodeType; + inputData: { + [key: string]: any; + }; + outputData: { + [key: string]: any; + }; + suggestions: { + [key: string]: string; + }; + /** + * Whether this node is a type that may contain other nodes. Those types + * should extend from GroupNode. + * + * For an OpNode, isGroupNode will be false, even though it may have + * embeddings. These embedding Nodes will have their parentNode set to the + * OpNode. However, embeddings are later rendered as annotations, not as + * children to be made visible on expansion (like a Metanode or SeriesNode). + */ + isGroupNode: boolean; + /** + * The number of nodes this node represents. For OpNodes, this will be 1, and + * for GroupNodes it will be a count of the total number of descendents it + * contains. + */ + cardinality: number; + /** + * The Node which is this Node's parent. This is of type Node and not + * GroupNode because of embeddings, which will have a parent OpNode. + */ + parentNode: Node; + /** Runtime execution stats for this node, if available */ + stats: NodeStats; + /** If the node is to be included or excluded from the main graph when + * rendered. Defaults to UNSPECIFIED, which means that the rendering + * algorithm determines if it will be included or not. Then can be set to + * INCLUDE or EXCLUDE manually by the user. + */ + include: InclusionType; + /** + * Node attributes specify customizable visual aspects of a node and + * application-specific metadata associated with a node. The name + * 'nodeAttributes' is meant to avoid naming-conflicts with the 'attr' in + * subclasses of Node. + */ + nodeAttributes: { + [key: string]: any; + }; +} +export type TensorShape = number[]; +export interface OpNode extends Node { + op: string; + // The device on which the op ran. Null if it is unknown. + device: string; + attr: { + key: string; + value: any; + }[]; + inputs: NormalizedInput[]; + inEmbeddings: OpNode[]; + outEmbeddings: OpNode[]; + // The name of the SeriesNode that can contain this node in its series. + // If there is no such node, then this is null. + owningSeries: string; + /** + * Object mapping output channel string to tensor shapes. The output channel + * is a string rather than a number because within TensorFlow functions, an + * output may be a cross between an output variable and a number (combined + * with a colon) such as "foo:2" rather than just a number alone. + * + * Each tensor shape is an array of numbers, or null. Details: + * - null means unknown rank, and therefore entire shape is unknown. + * - [4, 2, 1] means rank-3 tensor of size 4x2x1. + * - [] means a scalar (rank-0 tensor). + * - [1] means rank-1 tensor of size 1 (not the same as scalar). + * - [5, -1, 3] means rank-3 tensor of shape is 5x?x3. The size + * of the middle dimension is unknown (encoded as -1). + */ + outputShapes: { + [key: string]: TensorShape; + }; + // The XLA Cluster on which the op ran. Null if it is unknown. + xlaCluster: string; + // Whether op is compatible with its assigned device. Currently, if an op + // is not specified a device, the device is defaulted to the TPU. + // Furthermore, all ops are considered compatible for CPU and GPU devices, + // while a whitelist of compatible ops are specified for the TPU. + // Reference: opValid func in op.ts. + compatible: boolean; + // This field is only defined if the op node represents an input_arg to a + // library function. It is the index of the input_arg. + functionInputIndex: number; + // This field is only defined if the op node represents an output_arg of a + // library function. It is the index of the output_arg. + functionOutputIndex: number; + inputData: { + [key: string]: any; + }; + outputData: { + [key: string]: any; + }; + stackData: []; + matchedNodeLink: []; + suggestions: { + [key: string]: string; + }; +} +export interface BridgeNode extends Node { + /** + * Whether this bridge node represents edges coming into its parent node. + */ + inbound: boolean; +} +/** + * A node that is used when there are more than the maximum number of allowed + * annotations hanging off of a node. This node represents an ellipsis + * annotation, indicating a number of additional annotations. + */ +export interface EllipsisNode extends Node { + /** + * The number of nodes this ellipsis represents. + */ + numMoreNodes: number; + /** + * Sets the number of nodes this ellipsis represents and changes the node + * name accordingly. + */ + setNumMoreNodes(numNodes: number); +} +export interface GroupNode extends Node { + /** + * The metagraph contains nodes and metaedges between the immediate children + * of this group. The node label objects may be other GroupNodes (like + * SeriesNodes and Metanodes) or individual OpNodes. All edge label objects + * are Metaedges, each of which contains references to the original + * BaseEdge(s) from which it was created. + */ + metagraph: graphlib.Graph; + /** + * The bridgegraph contains only edges which link immediate children of this + * group with nodes outside of the metagraph. As in the metagraph, all edge + * label objects are Metaedges which contain references to the original + * BaseEdge(s) that contribute to it. + * + * For a Metaedge in the bridgegraph, its external endpoint will be the same + * as the metagraph edge from which it came. This is most easily explained + * by example. + * + * Consider an original graph that contains a BaseEdge A/B/C->Z/Y/X. + * + * +-------+ (BaseEdge) +-------+ + * | A/B/C |>----------------->| Z/Y/X | + * +-------+ +-------+ + * + * When we construct the Root's metagraph, it will contain nodes for A and Z, + * and a Metaedge A->Z. The A->Z Metaedge will contain the original BaseEdge + * A/B/C->Z/Y/X in its baseEdgeGraph. The Root's bridgegraph will always be + * empty. + * + * +---+ (Root.metagraph edge) +---+ + * | A |>--------------------------->| Z | + * +---+ +---+ + * + * Now consider the Metanode A. Its metagraph will contain a Metanode for A/B + * and no edges. A's bridgegraph will have one Metaedge from A/B->Z, which + * was derived from the Root's Metaedge A->Z. That Metaedge will contain the + * original BaseEdge in its baseEdgeGraph. + * + * +---------+ + * | A | + * | +---+ | (A.bridgegraph edge) +---+ + * | | B |>---------------------------->| Z | + * | +---+ | +---+ + * +---------+ + * + * Finally, consider the Metanode A/B. Its metagraph will contain a Metanode + * for A/B/C and again no edges. A/B's bridgegraph will have one Metaedge + * from A/B/C->Z, which was derived from A's bridgegraph Metaedge A/B->Z. + * As before, the A/B/C->Z Metaedge will contain the original BaseEdge in its + * baseEdgeGraph. + * + * +---------------+ + * | A | + * | +---------+ | + * | | B | | + * | | +---+ | | (A/B.bridgegraph edge) +---+ + * | | | C |>----------------------------------->| Z | + * | | +---+ | | +---+ + * | +---------+ | + * +---------------+ + * + * Likewise, under the Metanode Z and Z/Y, to compute the bridgegraph, we'll + * end up with Metaedges A->Z/Y and A->Z/Y/X respectively. So the original + * BaseEdge A/B/C->Z/Y/X becomes four different Metaedges in four different + * bridgegraphs: + * + * + A/B->Z in GroupNode A's bridgegraph, + * + A/B/C->Z in GroupNode A/B's bridgegraph, + * + A->Z/Y in GroupNode Z's bridgegraph, and + * + A->Z/Y/X in GroupNode Z/Y's bridgegraph. + * + * Considering any BaseEdge then, if N is the number of path segments in the + * source and M is the number of path segments in the destination, then the + * total number of bridgegraph edges you could create would be (N-1)(M-1). + * + * For this reason, it is computationally expensive to generate all the + * bridgegraphs for all the Metanodes, and instead they should be computed + * on demand as needed. + */ + bridgegraph: graphlib.Graph; + /** + * Stores how many times each device name appears in its children + * op nodes. Used to color group nodes by devices. + */ + deviceHistogram: { + [device: string]: number; + }; + /** + * Stores how many times each XLA cluster name appears in its children + * op nodes. Used to color group nodes by XLA clusters. + */ + xlaClusterHistogram: { + [device: string]: number; + }; + /** + * Stores how many ops in sub-graph were compatible and how many are + * incompatible. + */ + compatibilityHistogram: { + compatible: number; + incompatible: number; + }; + /** + * Flag indicating whether this GroupNode's metagraph contains any edges that + * are not control edges. Used to quickly determine how to draw a collapsed + * series (vertically or horizontally). + */ + hasNonControlEdges: boolean; +} +export interface Metanode extends GroupNode { + depth: number; + templateId: string; + opHistogram: { + [op: string]: number; + }; + attr: { + key: string; + value: any; + }[]; + // The name of the function this metanode is associated with if any. + associatedFunction: string; + inputData: { + [key: string]: any; + }; + outputData: { + [key: string]: any; + }; + stackData: []; + matchedNodeLink: []; + suggestions: { + [key: string]: string; + }; + getFirstChild(): GroupNode | OpNode; + getRootOp(): OpNode; + /** Return name of all leaves inside a metanode. */ + leaves(): string[]; +} +export interface SeriesNode extends GroupNode { + hasLoop: boolean; + prefix: string; + suffix: string; + clusterId: number; + ids: number[]; + parent: string; +} +export class EllipsisNodeImpl implements EllipsisNode { + name: string; + numMoreNodes: number; + stats: NodeStats; + type: NodeType; + isGroupNode: boolean; + cardinality: number; + parentNode: Node; + include: InclusionType; + inputData: { + [key: string]: any; + }; + outputData: { + [key: string]: any; + }; + suggestions: { + [key: string]: string; + }; + nodeAttributes: { + [key: string]: any; + }; + /** + * Constructs a new ellipsis annotation node. + * + * @param numNodes The number of additional annotations this node represents. + */ + constructor(numNodes: number) { + this.type = NodeType.ELLIPSIS; + this.isGroupNode = false; + this.cardinality = 1; + this.parentNode = null!; + this.stats = null!; + this.setNumMoreNodes(numNodes); + this.include = InclusionType.UNSPECIFIED; + } + setNumMoreNodes(numNodes: number) { + this.numMoreNodes = numNodes; + this.name = '... ' + numNodes + ' more'; + } +} +/** + * A label object for nodes in the full graph and leaf nodes in the render + * graph. + */ +export class OpNodeImpl implements OpNode { + name: string; + op: string; + device: string; + stats: NodeStats; + attr: { + key: string; + value: any; + }[]; + inputs: NormalizedInput[]; + type: NodeType; + isGroupNode: boolean; + cardinality: number; + inEmbeddings: OpNode[]; + outEmbeddings: OpNode[]; + parentNode: Node; + include: InclusionType; + owningSeries: string; + outputShapes: { + [key: string]: TensorShape; + }; + inputData: { + [key: string]: any; + }; + outputData: { + [key: string]: any; + }; + stackData: []; + matchedNodeLink: []; + suggestions: { + [key: string]: string; + }; + nodeAttributes: { + [key: string]: any; + }; + xlaCluster: string; + compatible: boolean; + // This field is only defined if the op node represents an input_arg to a + // library function. It is the index of the input_arg. + functionInputIndex: number; + // This field is only defined if the op node represents an output_arg of a + // library function. It is the index of the output_arg. + functionOutputIndex: number; + /** + * Constructs a new Op node. + * + * @param rawNode The raw node. + */ + constructor(rawNode: tf_graph_proto.NodeDef) { + this.op = rawNode.op; + this.name = rawNode.name; + this.device = rawNode.device; + this.attr = rawNode.attr; + // An array of normalized inputs that denote the incoming edges to + // the current node. Each input contains the normalized name of the + // source node, whether it has a number part and whether it is a + // control dependency. + this.inputs = normalizeInputs(rawNode.input); + this.outputShapes = extractOutputShapes(rawNode.attr); + this.xlaCluster = extractXlaCluster(rawNode.attr)!; + this.compatible = false; + // additional properties + this.type = NodeType.OP; + this.isGroupNode = false; + this.cardinality = 1; + this.inEmbeddings = []; + this.outEmbeddings = []; + this.parentNode = null!; + this.include = InclusionType.UNSPECIFIED; + this.owningSeries = null!; + this.inputData = rawNode.input_data; + this.outputData = rawNode.output_data; + this.suggestions = rawNode.suggestions; + this.stackData = rawNode.stack_info; + this.matchedNodeLink = rawNode.matched_node_link; + this.nodeAttributes = {}; + } +} +export function createMetanode(name: string, opt = {}): Metanode { + return new MetanodeImpl(name, opt); +} +/** + * Joins the information from the stats file (memory, compute time) with the + * graph information. + */ +export function joinStatsInfoWithGraph( + graph: SlimGraph, + stats: tf_graph_proto.StepStats, + devicesForStats?: { + [device: string]: boolean; + }, +): void { + // Reset stats for each node. + _.each(graph.nodes, (node) => { + node.stats = null!; + }); + _.each(stats.dev_stats, (devStats) => { + // Ignore devices that are not selected. + if (devicesForStats && !devicesForStats[devStats.device]) { + return; + } + _.each(devStats.node_stats, (nodeStats) => { + // Lookup the node in the graph by its original name, e.g. A/B. If not + // found, lookup by the rewritten name A/B/(B) in case the name is both + // a namespace and a node name. + let nodeName = nodeStats.node_name in graph.nodes ? nodeStats.node_name : getStrictName(nodeStats.node_name); + // Couldn't find a matching node. + if (!(nodeName in graph.nodes)) { + return; + } + // Compute the total bytes used. + let totalBytes = 0; + if (nodeStats.memory) { + _.each(nodeStats.memory, (alloc) => { + if (alloc.total_bytes) { + if (alloc.total_bytes > 0) { + totalBytes += Number(alloc.total_bytes); + } else { + /* tslint:disable */ + console.log('ignoring negative memory allocation for ' + nodeName); + /* tslint:enable */ + } + } + }); + } + let outputSize: number[][] = null!; + if (nodeStats.output) { + outputSize = _.map(nodeStats.output, (output) => { + return _.map(output.tensor_description.shape.dim, (dim) => Number(dim.size)); + }); + } + graph.nodes[nodeName].device = devStats.device; + if (graph.nodes[nodeName].stats == null) { + graph.nodes[nodeName].stats = new NodeStats(outputSize); + } + graph.nodes[nodeName].stats.addBytesAllocation(totalBytes); + if (nodeStats.all_end_rel_micros) { + if (nodeStats.all_end_rel_micros > 0) { + graph.nodes[nodeName].stats.addExecutionTime( + nodeStats.all_start_micros, + nodeStats.all_start_micros + nodeStats.all_end_rel_micros, + ); + } else { + /* tslint:disable */ + console.log('ignoring negative runtime for ' + nodeName); + /* tslint:enable */ + } + } + }); + }); +} +export class MetanodeImpl implements Metanode { + name: string; + stats: NodeStats; + type: NodeType; + depth: number; + isGroupNode: boolean; + cardinality: number; + metagraph: graphlib.Graph; + bridgegraph: graphlib.Graph; + templateId: string; + opHistogram: { + [op: string]: number; + }; + deviceHistogram: { + [op: string]: number; + }; + xlaClusterHistogram: { + [op: string]: number; + }; + compatibilityHistogram: { + compatible: number; + incompatible: number; + }; + parentNode: Node; + hasNonControlEdges: boolean; + include: InclusionType; + inputData: { + [key: string]: any; + }; + outputData: { + [key: string]: any; + }; + stackData: []; + matchedNodeLink: []; + suggestions: { + [key: string]: string; + }; + nodeAttributes: { + [key: string]: any; + }; + associatedFunction: string; + attr: { + key: string; + value: any; + }[]; + /** A label object for meta-nodes in the graph hierarchy */ + constructor(name: string, opt = {}) { + this.name = name; + this.type = NodeType.META; + /** number of levels under this group */ + this.depth = 1; + this.isGroupNode = true; + /** # of leaf nodes (including embedded ones) */ + this.cardinality = 0; + /** graph contains metanodes, nodes, edges + * and metaedges for main items within this metanode + */ + this.metagraph = createGraph(name, GraphType.META, opt); + /** bridgegraph must be constructed lazily-see hierarchy.getBridgegraph() */ + this.bridgegraph = null!; + /** + * A dictionary that count ops type of nodes in this metanode + * (op type => count). + */ + this.opHistogram = {}; + this.deviceHistogram = {}; + this.xlaClusterHistogram = {}; + this.compatibilityHistogram = { compatible: 0, incompatible: 0 }; + /** unique id for a metanode of similar subgraph */ + this.templateId = null!; + /** Metanode which contains this node, if any */ + this.parentNode = null!; + this.hasNonControlEdges = false; + this.include = InclusionType.UNSPECIFIED; + this.associatedFunction = ''; + this.attr = []; + this.inputData = {}; + this.outputData = {}; + this.stackData = []; + this.matchedNodeLink = []; + this.suggestions = {}; + this.nodeAttributes = {}; + } + getFirstChild(): GroupNode | OpNode { + return this.metagraph.node(this.metagraph.nodes()[0]) as any; + } + /** + * Returns the op node associated with the metanode. + * For example, if the metanode is 'sgd', the associated + * op node is sgd/(sgd). + */ + getRootOp(): OpNode { + let nameSplit = this.name.split('/'); + let rootOpName = this.name + '/(' + nameSplit[nameSplit.length - 1] + ')'; + return this.metagraph.node(rootOpName) as any; + } + /** + * Return an array of the names of all the leaves (non-GroupNodes) inside + * this metanode. This performs a breadth-first search of the tree, so + * immediate child leaves will appear earlier in the output array than + * descendant leaves. + */ + leaves(): string[] { + let leaves: string[] = []; + let queue = [this]; + let metagraph; // Defined here due to a limitation of ES6->5 compilation. + while (queue.length) { + let node = queue.shift(); + if (node?.isGroupNode) { + metagraph = (node).metagraph; + _.each(metagraph.nodes(), (name) => queue.push(metagraph.node(name))); + } else { + leaves.push(node?.name!); + } + } + return leaves; + } +} +export interface Metaedge { + /** + * Stores the original BaseEdges represented by this Metaedge. + */ + baseEdgeList: BaseEdge[]; + /** + * Whether this edge represents a relationship that is inbound (or outbound) + * to the object which contains this information. For example, in a Metanode's + * bridgegraph, each edge connects an immediate child to something outside + * the Metanode. If the destination of the edge is inside the Metanode, then + * its inbound property should be true. If the destination is outside the + * Metanode, then its inbound property should be false. + * + * The property is optional because not all edges can be described as + * inbound/outbound. For example, in a Metanode's metagraph, all of the edges + * connect immediate children of the Metanode. None should have an inbound + * property, or they should be null/undefined. + */ + inbound?: boolean; + /** + * Number of regular edges (not control dependency edges). + */ + numRegularEdges: number; + /** + * Number of reference edges, which is an edge to an operation + * that takes a reference to its input and changes its value. + */ + numRefEdges: number; + /** + * Total size (number of units) of all the tensors flowing through this edge. + */ + totalSize: number; + addBaseEdge(edge: BaseEdge, h: Hierarchy): void; + v?: string; + w?: string; +} +export function createMetaedge(v: string, w: string): Metaedge { + return new MetaedgeImpl(v, w); +} +/** + * A label object for edges between metanodes of subgraphs in the render graph. + */ +export class MetaedgeImpl implements Metaedge { + v: string; + w: string; + baseEdgeList: BaseEdge[]; + inbound: boolean; + numRegularEdges: number; + numRefEdges: number; + totalSize: number; + constructor(v: string, w: string) { + this.v = v; + this.w = w; + this.baseEdgeList = []; + this.inbound = null!; + this.numRegularEdges = 0; + this.numRefEdges = 0; + this.totalSize = 0; + } + addBaseEdge(edge: BaseEdge, h: Hierarchy): void { + this.baseEdgeList.push(edge); + this.numRegularEdges += 1; + if (edge.isReferenceEdge) { + this.numRefEdges += 1; + } + // Compute the size of the tensor flowing through this + // base edge. + this.totalSize += (JSON.parse(edge.outputTensorKey) as number[]).reduce((accumulated, currSize) => { + if (currSize === -1) { + currSize = 1; + } + return accumulated * currSize; + }, 1); + h.maxMetaEdgeSize = Math.max(h.maxMetaEdgeSize, this.totalSize); + } + private static computeSizeOfEdge(edge: BaseEdge, h: Hierarchy): number { + let opNode = h.node(edge.v!); + if (!opNode.outputShapes) { + // No shape information. Asssume a single number. This gives + // a lower bound for the total size. + return 1; + } + h.hasShapeInfo = true; + // Sum the sizes of all output tensors. + // TODO(stephanwlee): Use Object.values after es2017. + const values = Object.keys(opNode.outputShapes) + .map((k) => opNode.outputShapes[k]) + .map((shape: number[]) => { + // If the shape is unknown, treat it as 1 when computing + // total size. This gives a lower bound for the total size. + if (shape == null) { + return 1; + } + // Multiply all shapes to get the total size of the tensor. + // E.g. The total size of [4, 2, 1] is 4 * 2 * 1. + return shape.reduce((accumulated, currSize) => { + // If this particular dimension is unknown, treat + // it as 1 when computing total size. This gives a lower bound + // for the total size. + if (currSize === -1) { + currSize = 1; + } + return accumulated * currSize; + }, 1); + }); + return _.sum(values); + } +} +export function createSeriesNode( + prefix: string, + suffix: string, + parent: string, + clusterId: number, + name: string, + graphOptions: LabeledGraphOptions, +): SeriesNode { + return new SeriesNodeImpl(prefix, suffix, parent, clusterId, name, graphOptions); +} +export function getSeriesNodeName( + prefix: string, + suffix: string, + parent: string, + startId?: number, + endId?: number, +): string { + let numRepresentation = + typeof startId !== 'undefined' && typeof endId !== 'undefined' ? '[' + startId + '-' + endId + ']' : '#'; + let pattern = prefix + numRepresentation + suffix; + return (parent ? parent + '/' : '') + pattern; +} +class SeriesNodeImpl implements SeriesNode { + name: string; + type: NodeType; + stats: NodeStats; + hasLoop: boolean; + prefix: string; + suffix: string; + clusterId: number; + ids: number[]; + parent: string; + isGroupNode: boolean; + cardinality: number; + metagraph: graphlib.Graph; + bridgegraph: graphlib.Graph; + parentNode: Node; + deviceHistogram: { + [op: string]: number; + }; + xlaClusterHistogram: { + [op: string]: number; + }; + compatibilityHistogram: { + compatible: number; + incompatible: number; + }; + hasNonControlEdges: boolean; + include: InclusionType; + inputData: { + [key: string]: any; + }; + outputData: { + [key: string]: any; + }; + stackData: []; + matchedNodeLink: []; + suggestions: { + [key: string]: string; + }; + nodeAttributes: { + [key: string]: any; + }; + constructor( + prefix: string, + suffix: string, + parent: string, + clusterId: number, + name: string, + graphOptions: LabeledGraphOptions, + ) { + this.name = name || getSeriesNodeName(prefix, suffix, parent); + this.type = NodeType.SERIES; + this.hasLoop = false; + this.prefix = prefix; + this.suffix = suffix; + this.clusterId = clusterId; + this.ids = []; + this.parent = parent; + this.isGroupNode = true; + this.cardinality = 0; + this.metagraph = createGraph(name, GraphType.SERIES, graphOptions); + // bridgegraph must be constructed lazily-see hierarchy.getBridgegraph() + this.bridgegraph = null!; + this.parentNode = null!; + this.deviceHistogram = {}; + this.xlaClusterHistogram = {}; + this.compatibilityHistogram = { compatible: 0, incompatible: 0 }; + this.hasNonControlEdges = false; + this.include = InclusionType.UNSPECIFIED; + } +} +/** + * Extracts the shapes of the output tensors from the attr property in the + * node proto. + */ +// tslint:disable-next-line:no-any +function extractOutputShapes( + attr: Array<{ + key: string; + value: any; + }>, +): { + [key: string]: TensorShape; +} { + let result = null; + // We don't know anything about the output tensors. + if (!attr) { + return null!; + } + for (let i = 0; i < attr.length; i++) { + let { key, value } = attr[i]; + if (key === OUTPUT_SHAPES_KEY) { + if (!value.list || !value.list.shape) { + // The OUTPUT_SHAPES_KEY lacks a value. We know nothing about the shape. + return null!; + } + // Map all output tensors into array of numbers denoting their shape. + let result = value.list.shape.map((shape) => { + if (shape.unknown_rank) { + // This output tensor is of unknown rank. We don't know if it is a + // scalar, or a tensor, or of what shape it is. + return null!; + } + if (shape.dim == null || (shape.dim.length === 1 && shape.dim[0].size == null)) { + // This output tensor is a scalar. + return []; + } + // This output tensor has a known rank. Map each dimension size + // into a number. + return shape.dim.map((dim) => { + // Size can be -1 if this particular dimension is unknown. + // If we actually have a 0-dimension tensor `dim.size` returns null, + // so we default to 0 in this case to avoid upstream null-handling + // issues. + return dim.size || 0; + }); + }); + // Since we already processed it, remove the entry from the attribute + // list (saves memory). + attr.splice(i, 1); + return result; + } + } + // We didn't find OUTPUT_SHAPES_KEY in attributes, so we don't know anything + // about the output tensors. + return null!; +} +/** + * Extracts the XLA Cluster that an op runs on from the attrs of the OpNode. + * @param attr The attr property. + * @return A string that is the name of the cluster. Or null if it could not be + * determined. + */ +// tslint:disable-next-line:no-any +function extractXlaCluster( + attr: Array<{ + key: string; + value: any; + }>, +): string | null { + if (!attr) { + return null; + } + // Find the attribute for XLA cluster if there is one. + for (let i = 0; i < attr.length; i++) { + if (attr[i].key === _XLA_CLUSTER_KEY) { + return attr[i].value['s'] || null; + } + } + return null; +} + +/** + * Matches node name that encodes output tensor name and/or its index. + * - : + * - :: + */ +const INPUT_NAME_PART_MATCHER = /^([^:]+):((\w+:|)\d+)$/; + +/** + * Normalizes the inputs and extracts associated metadata: + * 1) Inputs can contain a colon followed by a suffix of characters. + * That suffix may be a single number (e.g. inputName:1) or several word + * characters separated from a number by a colon (e.g. inputName:foo:1). The + * latter case is used to denote inputs and outputs of functions. + * 2) Control dependency inputs contain caret at the beginning and we + * remove this and annotate the edge as a control dependency. + * @param inputs Array of unnormalized names of input nodes. + */ +function normalizeInputs(inputs: string[] | undefined): NormalizedInput[] { + const normalizedInputs: NormalizedInput[] = []; + let lastName: string | null = null; + for (let inputName of inputs || []) { + let name = inputName; + let outputTensorKey = '0'; + const match = inputName.includes(':') && inputName.match(INPUT_NAME_PART_MATCHER); + if (match) { + // The output string consists of optionally several characters and a number + // separated by a colon. + name = match[1]; + outputTensorKey = match[2]; + } + + if (lastName !== name) { + lastName = name; + normalizedInputs.push({ + name: name, + outputTensorKey: outputTensorKey, + }); + } + } + return normalizedInputs; +} +function addEdgeToGraph( + graph: SlimGraph, + inputName: string, + outputNode: OpNode, + input: NormalizedInput, + params: BuildParams, + index: number, +) { + // Don't allow loops in the graph. + if (inputName === outputNode.name) { + return; + } + // Check if this op type and input number corresponds to a + // reference edge using the refEdges dictionary in the params. + let isRefEdge = params.refEdges[outputNode.op + ' ' + index] === true; + graph.edges.push({ + v: inputName, + w: outputNode.name, + outputTensorKey: input.outputTensorKey, + isReferenceEdge: isRefEdge, + }); +} +function addEdgeToGraphByAttr(graph: SlimGraph, node: OpNode | Metanode, edgeInfo: any) { + const { shape, source, target, ...attr } = edgeInfo; + // Don't allow loops in the graph. + if ((!target && !source) || edgeInfo.target === node.name || edgeInfo.source === node.name) { + return; + } + let outputTensorKey = shape; + if (!shape || shape === 'N/A') { + outputTensorKey = '[]'; + } + const edge = { + v: !!target ? node.name : source, + w: target || node.name, + outputTensorKey, + isReferenceEdge: false, + attr, + }; + if (!graph.edges.find((item) => JSON.stringify(item) === JSON.stringify(edge))) { + graph.edges.push(edge); + } +} +export const DefaultBuildParams: BuildParams = { + enableEmbedding: true, + inEmbeddingTypes: ['Const'], + outEmbeddingTypes: ['^[a-zA-Z]+Summary$'], + // This is the whitelist of inputs on op types that are considered + // reference edges. "Assign 0" indicates that the first input to + // an OpNode with operation type "Assign" is a reference edge. + refEdges: { + 'Assign 0': true, + 'AssignAdd 0': true, + 'AssignSub 0': true, + 'assign 0': true, + 'assign_add 0': true, + 'assign_sub 0': true, + 'count_up_to 0': true, + 'ScatterAdd 0': true, + 'ScatterSub 0': true, + 'ScatterUpdate 0': true, + 'scatter_add 0': true, + 'scatter_sub 0': true, + 'scatter_update 0': true, + }, +}; +export function build( + graphDef: tf_graph_proto.GraphDef, + params: BuildParams, + tracker?: ProgressTracker, +): Promise { + /** + * A dictionary that maps each in-embedding node name to the node + * object. + */ + let inEmbedding: { + [nodeName: string]: OpNode; + } = {}; + /** + * A dictionary that maps each out-embedding node name to the node + * object. + */ + let outEmbedding: { + [nodeName: string]: OpNode; + } = {}; + /** + * A dictionary that maps each node name to an array of the node's + * out-embedding node label objects. + */ + let outEmbeddings: { + [inputName: string]: OpNode[]; + } = {}; + let isInEmbeddedPred = getEmbedPredicate(params.inEmbeddingTypes); + let isOutEmbeddedPred = getEmbedPredicate(params.outEmbeddingTypes); + let embeddingNodeNames: string[] = []; + let rawNodes = graphDef.node; + /** + * A list of all the non-embedding node names which appear in the processed + * list of raw nodes. Here we pre-allocate enough room for all the rawNodes, + * even though there will some number of embeddings. The excess array length + * is spliced off later. + * + * Experimentation shows that around 30% of the array will go unused, and + * even for very large networks that amounts to less than 10k spaces. + */ + let nodeNames = new Array(rawNodes.length); + return tf_graph_util + .runAsyncTask( + 'Normalizing names', + 30, + () => { + let opNodes = new Array(rawNodes.length); + let index = 0; + const processRawNode = (rawNode: tf_graph_proto.NodeDef) => { + if (!rawNode.isLeaf) { + let metaNode = new MetanodeImpl(rawNode.name); + metaNode.attr = rawNode.attr; + metaNode.nodeAttributes['_order'] = index; + if (rawNode.matched_node_link && rawNode.matched_node_link.length > 0) { + metaNode.nodeAttributes['_linked_node'] = rawNode.matched_node_link; + } + metaNode.inputData = rawNode.input_data; + metaNode.outputData = rawNode.output_data; + metaNode.stackData = rawNode.stack_info; + metaNode.matchedNodeLink = rawNode.matched_node_link; + metaNode.suggestions = rawNode.suggestions; + if (Number(rawNode.node_type) == 1) { + metaNode.type == 0; + } else { + metaNode.type = Number(rawNode.node_type); + } + opNodes[index] = metaNode; + nodeNames[index] = metaNode.name; + index++; + return metaNode; + } else { + let opNode = new OpNodeImpl(rawNode); + opNode.nodeAttributes['_order'] = index; + if (rawNode.matched_node_link && rawNode.matched_node_link.length > 0) { + opNode.nodeAttributes['_linked_node'] = rawNode.matched_node_link; + } + if (isInEmbeddedPred(opNode)) { + embeddingNodeNames.push(opNode.name); + inEmbedding[opNode.name] = opNode; + return opNode; + } + if (isOutEmbeddedPred(opNode)) { + embeddingNodeNames.push(opNode.name); + outEmbedding[opNode.name] = opNode; + _.each(opNode.inputs, (input) => { + let inputName = input.name; + outEmbeddings[inputName] = outEmbeddings[inputName] || []; + outEmbeddings[inputName].push(opNode); + }); + return opNode; + } + // The node is not an embedding, so add it to the names and nodes + // lists. + opNodes[index] = opNode; + nodeNames[index] = opNode.name; + index++; + return opNode; + } + }; + _.each(rawNodes, processRawNode); + opNodes.splice(index); + nodeNames.splice(index); + return opNodes; + }, + tracker, + tb_debug.GraphDebugEventId.NORMALIZING_NAMES, + ) + .then((opNodes) => { + // Create the graph data structure from the graphlib library. + return tf_graph_util.runAsyncTask( + 'Building the data structure', + 70, + () => { + let normalizedNameDict = mapStrictHierarchy(nodeNames, embeddingNodeNames); + let graph = new SlimGraph(); + // Add the nodes to the graph. + _.each(opNodes, (opNode) => { + if (opNode instanceof OpNodeImpl) { + let normalizedName = normalizedNameDict[opNode.name] || opNode.name; + graph.nodes[normalizedName] = opNode; + // Check if the node has out-embeddings. If yes, add them to the + // node. + if (opNode.name in outEmbeddings) { + opNode.outEmbeddings = outEmbeddings[opNode.name]; + // Normalize the names of the out-embeddings. + _.each(opNode.outEmbeddings, (node) => { + node.name = normalizedNameDict[node.name] || node.name; + }); + } + // Update the name of the node. + opNode.name = normalizedName; + } else { + graph.metaNodes[opNode.name] = opNode as MetanodeImpl; + } + }); + // Visit each node's inputs to add the edges to the graph. If the + // input + // is an in-embedding, then add it to the node's in-embeddings + // instead. + _.each(opNodes, (opNode) => { + _.each(opNode.attr, ({ key, value }) => { + if (key === 'edge_info') { + addEdgeToGraphByAttr(graph, opNode, value); + } + }); + // Removes repeated edge info. + opNode.attr = _.filter(opNode.attr, ({ key, value }) => key !== 'edge_info'); + }); + // Normalize the names of in-embeddings. + _.each(inEmbedding, (node, name) => { + node.name = normalizedNameDict[node.name] || node.name; + }); + return graph; + }, + tracker, + tb_debug.GraphDebugEventId.BUILD_SLIM_GRAPH, + ); + }); +} +/** + * Create a new graphlib.Graph() instance with default parameters + */ +export function createGraph(name: string, type, graphOptions: LabeledGraphOptions = {}): graphlib.Graph { + const graph = new graphlib.Graph({ ...graphOptions, multigraph: true }); + graph.setGraph({ + name: name, + rankdir: graphOptions.rankdir || 'TB', + type: type, + } as any); + return graph; +} +/** + * Create a predicate for checking whether a node should be embedded based on + * the specified types. + */ +function getEmbedPredicate(types: string[]) { + return function (node: OpNode) { + // check types + for (let i = 0; i < types.length; i++) { + let regExp = new RegExp(types[i]); + if (typeof node.op === 'string' && node.op.match(regExp)) { + return true; + } + } + return false; + }; +} +/** + * Returns a strict node name (name => name/(name)) to avoid conflicts + * where the node name is also a namespace. + */ +export function getStrictName(name: string): string { + let parts = name.split(NAMESPACE_DELIM); + return name + NAMESPACE_DELIM + '(' + parts[parts.length - 1] + ')'; +} +/** + * For each op node (embedding or non-embedding), rename it if there is a + * non-embedding node under its namespace. For example, assume node name 'A'. + * If there is a non-embedding node under its namespace (e.g. 'A/B'), 'A' will + * be renamed to 'A/(A)'. Then the namespace 'A' will contain 2 nodes: '(A)' + * and 'B'. If all the nodes under 'A' are embedding nodes (e.g. constant and + * summary), keep 'A' as an Op node and don't create a namespace. + * + * @param nodeNames An array of regular (non-embedding) node names. + * @param embeddingNodeNames An array of embedding node names. + * @return Dictionary object mapping names that need to be renamed to + * new names. + */ +function mapStrictHierarchy( + nodeNames: string[], + embeddingNodeNames: string[], +): { + [oldName: string]: string; +} { + /** Dictionary that maps the old new to the new name */ + let newNameDictionary: { + [oldName: string]: string; + } = {}; + /** Set used to store all namespaces. */ + let namespaceSet: { + [namespace: string]: boolean; + } = {}; + // sort the nodes to make prefix check faster + nodeNames.sort(); + // look for nodes with a prefix a,a/b -> a/(a),a/b + for (let i = 0; i < nodeNames.length - 1; ++i) { + let a = nodeNames[i]; + // Get all the parent namespaces of the current node + // and add them in the namespace set. + _.each(getHierarchicalPath(a).slice(0, -1), (ns) => { + namespaceSet[ns] = true; + }); + for (let j = i + 1; j < nodeNames.length; ++j) { + let b = nodeNames[j]; + if (_.startsWith(b, a)) { + if (b.length > a.length && b.charAt(a.length) === NAMESPACE_DELIM) { + newNameDictionary[a] = getStrictName(a); + break; + } + } else { + break; + } + } + } + // Go through all the embedding node names and rename them in case they + // collide with namespaces. + _.each(embeddingNodeNames, (embeddingName) => { + if (embeddingName in namespaceSet) { + // Rename to follow strict hierarchy. + newNameDictionary[embeddingName] = getStrictName(embeddingName); + } + }); + return newNameDictionary; +} +/** + * Returns a list of the degrees of each node in the graph. + */ +function degreeSequence(graph: graphlib.Graph): number[] { + let degrees = graph.nodes().map(function (name) { + return graph.neighbors(name)?.length!; + }); + degrees.sort(); + return degrees; +} +/** + * Returns if the degree sequence of the two graphs is the same. + */ +export function hasSimilarDegreeSequence(graph1: graphlib.Graph, graph2: graphlib.Graph): boolean { + let dg1 = degreeSequence(graph1); + let dg2 = degreeSequence(graph2); + for (let i = 0; i < dg1.length; i++) { + if (dg1[i] !== dg2[i]) { + return false; + } + } + return true; +} +/** + * Returns the hierarchical path of the current node, based on the node's name. + * For example, if the name is 'a/b/c', the returned path is + * ['a', 'a/b', 'a/b/c']. + */ +export function getHierarchicalPath( + name: string, + seriesNames?: { + [name: string]: string; + }, +): string[] { + let path: string[] = []; + let i = name.indexOf(NAMESPACE_DELIM); + // Push all parent portions of the path. + while (i >= 0) { + path.push(name.substring(0, i)); + i = name.indexOf(NAMESPACE_DELIM, i + 1); + } + // If the node's path is under a series, then add the series node name to the + // hierarchical path as the parent of the leaf. + if (seriesNames) { + let seriesName = seriesNames[name]; + if (seriesName) { + path.push(seriesName); + } + } + // Push the leaf of the path. + path.push(name); + return path; +} +/** + * Returns the string for the node inclusion toggle button, dependant + * on the provided current InclusionType. + */ +export function getIncludeNodeButtonString(include: InclusionType) { + if (include === InclusionType.EXCLUDE) { + return 'Add to main graph'; + } else { + return 'Remove from main graph'; + } +} +/** + * Returns the string for the series node grouping toggle button, dependant + * on the provided current SeriesGroupingType. + */ +export function getGroupSeriesNodeButtonString(group: SeriesGroupingType) { + if (group === SeriesGroupingType.GROUP) { + return 'Ungroup this series of nodes'; + } else { + return 'Group this series of nodes'; + } +} + +export interface Edges { + control: Metaedge[]; + regular: Metaedge[]; +} +/** + * Class used to store data on library functions. This specifically stores data + * on the library function, not individual calls to those functions. + */ +export interface LibraryFunctionData { + // The metanode representing this function in the library scene group. + node: Metanode; + // A list of nodes that represent calls to this library function. + usages: Node[]; +} + +/** + * An extended variant of the options object for `graphlib.Graph`, used + * to configure a `graphlib.Graph` at its creation. + * + * Dagre's constructor has an `opts` object as a parameter, let's call it + * 'GraphCtorOptions'. The Graph's `setGraph()` has a `label` parameter, + * let's call it `LabelOptions`. + * + * Since both are configured when a `graphlib.Graph` is first initialized, + * TensorBoard's Graph code passes around this hybrid object which includes + * properties from both `GraphCtorOptions` (compound) and `LabelOptions` + * (rankdir). + */ +export type LabeledGraphOptions = { + compound?: boolean; + rankdir?: string; + multigraph?: boolean; +}; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/hierarchy.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/hierarchy.ts new file mode 100644 index 0000000000000000000000000000000000000000..1cbd1220b5de6c501ca373d5f664537b4d4a217b --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/hierarchy.ts @@ -0,0 +1,1076 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/** + * Package for the Graph Hierarchy for TensorFlow graph. + */ +import * as d3 from 'd3'; +import { graphlib } from 'dagre'; +import * as _ from 'lodash'; +import * as tb_debug from '../tb_debug'; +import { NodeStats, ProgressTracker } from './common'; +import * as tf_graph from './graph'; +import { + createGraph, + createMetaedge, + createMetanode, + createSeriesNode, + Edges, + getHierarchicalPath, + getSeriesNodeName, + GraphType, + GroupNode, + Metaedge, + MetaedgeImpl, + Metanode, + Node, + NodeType, + OpNode, + ROOT_NAME, + SeriesNode, + SlimGraph, + MetanodeImpl, +} from './graph'; +import * as tf_graph_proto from './proto'; +import * as template from './template'; +import * as tf_graph_util from './util'; + +export enum HierarchyEvent { + /** + * Fired when the templates may have been updated. No event payload attached. + */ + TEMPLATES_UPDATED, +} + +// A map from the name of a series node to its grouping type. +type SeriesGroupMap = Map; + +/** + * Class for the Graph Hierarchy for TensorFlow graph. + */ +export class Hierarchy extends tf_graph_util.Dispatcher { + root: Metanode; + devices: string[]; + /** + * Whether at least one tensor in the graph has shape information. + */ + hasShapeInfo = false; + /** + * The maximum size across all meta edges. Used for scaling thickness. + */ + maxMetaEdgeSize = 1; + orderings: { + [nodeName: string]: { + [childName: string]: number; + }; + }; + /** + * Options passed to dagre for creating the graph. Note that the + * `compound` argument will be overridden to true. + */ + graphOptions: tf_graph.LabeledGraphOptions = {}; + + private verifyTemplate: boolean; + private templates: { + [templateId: string]: string[]; + } | null = null; + private index: { + [nodeName: string]: GroupNode | OpNode; + }; + private readonly seriesGroupMap: SeriesGroupMap; + + constructor(params: HierarchyParams) { + super(); + this.graphOptions.compound = true; + this.graphOptions.rankdir = params.rankDirection; + this.root = createMetanode(ROOT_NAME, this.graphOptions); + this.seriesGroupMap = new Map(params.seriesMap); + this.devices = null!; + this.verifyTemplate = params.verifyTemplate; + /** + * @type {Object} Dictionary object that maps node name to the node + * (could be op-node, metanode, or series-node) + */ + this.index = {}; + this.index[ROOT_NAME] = this.root; + this.orderings = {}; + } + getSeriesGroupType(nodeName: string): tf_graph.SeriesGroupingType { + // If grouping was not specified, assume it should be grouped by default. + return this.seriesGroupMap.get(nodeName) ?? tf_graph.SeriesGroupingType.GROUP; + } + setSeriesGroupType(nodeName: string, groupType: tf_graph.SeriesGroupingType) { + return this.seriesGroupMap.set(nodeName, groupType); + } + buildSeriesGroupMapToggled(nodeName: string): Map { + const newGroupType = + this.getSeriesGroupType(nodeName) === tf_graph.SeriesGroupingType.GROUP + ? tf_graph.SeriesGroupingType.UNGROUP + : tf_graph.SeriesGroupingType.GROUP; + return new Map([...this.seriesGroupMap, [nodeName, newGroupType]]); + } + getNodeMap(): { + [nodeName: string]: GroupNode | OpNode; + } { + return this.index; + } + node(name: string): GroupNode | OpNode { + return this.index[name]; + } + setNode(name: string, node: GroupNode | OpNode): void { + this.index[name] = node; + } + /** + * Given the name of a node in this hierarchy, get its bridgegraph, creating + * it on the fly if necessary. If the node is not a GroupNode, then this + * method returns null. If the provided name does not map to a node in the + * hierarchy, an error will be thrown. + */ + getBridgegraph(nodeName: string): graphlib.Graph { + let node = this.index[nodeName]; + if (!node) { + throw Error('Could not find node in hierarchy: ' + nodeName); + } + if (!('metagraph' in node)) { + return null!; + } + let groupNode = node; + if (groupNode.bridgegraph) { + return groupNode.bridgegraph; + } + let bridgegraph = (groupNode.bridgegraph = createGraph( + 'BRIDGEGRAPH', + GraphType.BRIDGE, + this.graphOptions, + )); + if (!node.parentNode || !('metagraph' in node.parentNode)) { + return bridgegraph; + } + let parentNode = node.parentNode; + let parentMetagraph = parentNode.metagraph; + let parentBridgegraph = this.getBridgegraph(parentNode.name); + // For each of the parent node's two Metaedge containing graphs, process + // each Metaedge involving this node. + _.each([parentMetagraph, parentBridgegraph], (parentGraph) => { + parentGraph + .edges() + .filter((e) => e.v === nodeName || e.w === nodeName) + .forEach((parentEdgeObj) => { + let inbound = parentEdgeObj.w === nodeName; + let parentMetaedge = parentGraph.edge(parentEdgeObj); + // The parent's Metaedge represents some number of underlying + // BaseEdges from the original full graph. For each of those, we need + // to determine which immediate child is involved and make sure + // there's a Metaedge in the bridgegraph that covers it. + _.each(parentMetaedge.baseEdgeList, (baseEdge) => { + // Based on the direction, figure out which is the descendant node + // and which is the 'other' node (sibling of parent or ancestor). + let [descendantName, otherName] = inbound ? [baseEdge.w, parentEdgeObj.v] : [baseEdge.v, parentEdgeObj.w]; + // Determine the immediate child containing this descendant node. + if (nodeName !== descendantName) { + let childName = this.getChildName(nodeName, descendantName); + if (!childName) { + return; + } + // Look for an existing Metaedge in the bridgegraph (or create a + // new one) that covers the relationship between child and other. + let bridgeEdgeObj = { + v: inbound ? otherName : childName, + w: inbound ? childName : otherName, + }; + let bridgeMetaedge = bridgegraph.edge(bridgeEdgeObj); + if (!bridgeMetaedge) { + bridgeMetaedge = createMetaedge(bridgeEdgeObj.v, bridgeEdgeObj.w) as any; + bridgeMetaedge.inbound = inbound; + bridgegraph.setEdge(bridgeEdgeObj.v, bridgeEdgeObj.w, bridgeMetaedge, baseEdge.attr?.id); + } + // Copy the BaseEdge from the parent's Metaedge into this + // bridgegraph Metaedge. + bridgeMetaedge.addBaseEdge(baseEdge, this); + } + }); + }); + }); + return bridgegraph; + } + /** + * Utility function for determining the name of the immediate child under a + * node for a given descendant path. If the descendant corresponds to no + * immediate child, an error is thrown. + */ + private getChildName(nodeName: string, descendantName: string): string { + // Walk up the hierarchy from the descendant to find the child. + let currentNode: Node = this.index[descendantName]; + if (!currentNode) { + return ''; + } + while (currentNode) { + if (currentNode.parentNode && currentNode.parentNode.name === nodeName) { + return currentNode.name; + } + currentNode = currentNode.parentNode; + } + throw Error('Could not find immediate child for descendant: ' + descendantName); + } + /** Given the name of a node, return its incoming metaedges. */ + getPredecessors(nodeName: string): Edges { + let node = this.index[nodeName]; + if (!node) { + throw Error('Could not find node with name: ' + nodeName); + } + let predecessors = this.getOneWayEdges(node, true); + // Add embedded predecessors, such as constants. + if (!node.isGroupNode) { + _.each((node).inEmbeddings, (embeddedNode) => { + _.each((node).inputs, (input) => { + if (input.name === embeddedNode.name) { + // Make a new metaedge holding the edge between the + // node and the in-embedding. + let metaedge = new MetaedgeImpl(embeddedNode.name, nodeName); + metaedge.addBaseEdge( + { + outputTensorKey: input.outputTensorKey, + isReferenceEdge: false, + v: embeddedNode.name, + w: nodeName, + }, + this, + ); + predecessors.regular.push(metaedge); + } + }); + }); + } + return predecessors; + } + /** + * Given the name of a node, return its outgoing metaedges. + * + * This is the inverse of getPredecessors(). See that method's documentation + * for an in-depth example. + */ + getSuccessors(nodeName: string): Edges { + let node = this.index[nodeName]; + if (!node) { + throw Error('Could not find node with name: ' + nodeName); + } + let successors = this.getOneWayEdges(node, false); + // Add embedded successors, such as summaries. + if (!node.isGroupNode) { + _.each((node).outEmbeddings, (embeddedNode) => { + _.each(embeddedNode.inputs, (input) => { + if (input.name === nodeName) { + // Make a new metaedge holding the edge between the + // node and the out-embedding. + let metaedge = new MetaedgeImpl(nodeName, embeddedNode.name); + metaedge.addBaseEdge( + { + outputTensorKey: input.outputTensorKey, + isReferenceEdge: false, + v: nodeName, + w: embeddedNode.name, + }, + this, + ); + successors.regular.push(metaedge); + } + }); + }); + } + return successors; + } + /** Helper method for getPredecessors and getSuccessors */ + private getOneWayEdges(node: GroupNode | OpNode, inEdges: boolean) { + let edges: Edges = { control: [], regular: [] }; + // A node with no parent cannot have any edges. + if (!node.parentNode || !node.parentNode.isGroupNode) { + return edges; + } + let parentNode = node.parentNode; + let metagraph = parentNode.metagraph; + let bridgegraph = this.getBridgegraph(parentNode.name); + findEdgeTargetsInGraph(metagraph, node, inEdges, edges); + findEdgeTargetsInGraph(bridgegraph, node, inEdges, edges); + return edges; + } + /** + * For a given GroupNode, get or calculate an object which describes a + * topological ordering of child nodes within that GroupNode's metagraph. + * + * This ordering is used when rendering bridge control edges which are + * sometimes backwards relative to the dataflow. + * + * For example, say we have a graph with two edges A->B and A->C, and we're + * interested in the ordering under ROOT. In this case, any of the following + * would be legitimate return values: + * + * - { 'A': 0, 'B': 1, 'C': 2 } -- most likely + * - { 'A': 0, 'B': 2, 'C': 1 } -- less likely + * - { 'A': 12, 'B': 100, 'C': 99 } -- unlikely, but still OK + * + * The algorithm does not guarantee that all numbers from 0-N (where N is + * the number of nodes) appear exactly once. Rather it guarantees that if + * there is a path between two nodes, the earlier one will have a lower + * number in the ordering hash. + * + * When generating the ordering, we ignore control Metaedges (those which + * represent only BaseEdges that have isControlDependency set to true). + * + * If there is no node with the specified name, an error is thrown. If the + * node with the specified name is not a group node, null is returned. + */ + getTopologicalOrdering(nodeName: string): { + [childName: string]: number; + } { + let node = this.index[nodeName]; + if (!node) { + throw Error('Could not find node with name: ' + nodeName); + } + if (!node.isGroupNode) { + return null!; + } + if (nodeName in this.orderings) { + return this.orderings[nodeName]; + } + // Mapping of a child node names to lists of their successors. + let successors: { + [childName: string]: string[]; + } = {}; + // Set of node names which have appeared as a destination. + let destinations: { + [childName: string]: boolean; + } = {}; + let metagraph = (node).metagraph; + _.each(metagraph.edges(), (e: any) => { + if (!metagraph.edge(e).numRegularEdges) { + return; // Skip control edges. + } + // Keep track of successors and destinations. + if (!(e.v in successors)) { + successors[e.v] = []; + } + successors[e.v].push(e.w); + destinations[e.w] = true; + }); + // Seed the queue with true sources (those that are not destinations). + let queue: string[] = _.difference(_.keys(successors), _.keys(destinations)); + // Produce an ordering by traversing the graph breadth first. + let ordering = (this.orderings[nodeName] = {}); + let index = 0; + while (queue.length) { + let childName = queue.shift(); + ordering[childName!] = index++; + _.each(successors[childName!], (succName) => queue.push(succName)); + delete successors[childName!]; // Prevent cycles from infinite looping. + } + return ordering; + } + /** + * Returns a d3 Ordinal function that can be used to look up the index of + * a node based on its template id. + * + * When templates update, the Hierarchy will dispatch an event + * `HierarchyEvent.TEMPLATES_UPDATED` to consumers. + */ + getTemplateIndex(): (string) => number | null { + if (!this.templates) { + return null!; + } + let templateNames = d3.keys(this.templates); + if (!templateNames.length) { + return null!; + } + let templateIndex = d3.scaleOrdinal().domain(templateNames).range(d3.range(0, templateNames.length)); + return (templateId: string) => templateIndex(templateId); + } + + /** + * Statically computes the templateId for every MetaNode in the graph, even + * if it is never rendered. This may be a very expensive call for large + * graphs. + */ + updateTemplates() { + tf_graph_util.time( + 'Finding similar subgraphs', + () => { + this.templates = template.detect(this, this.verifyTemplate); + this.dispatchEvent(HierarchyEvent.TEMPLATES_UPDATED); + }, + tb_debug.GraphDebugEventId.HIERARCHY_FIND_SIMILAR_SUBGRAPHS, + ); + } +} +/** + * Internal utility function - given a graph (should be either a metagraph or a + * bridgegraph) and a node which is known to be in that graph, determine + * the other ends of edges that involve that node in the direction specified + * by whether it's inbound. + * + * For example if you wanted to find the predecessors of a node, you'd call + * this method for the parent's metagraph and bridgegraph, specifying inbound + * as true (look at the source of inbound edges to the specified node). + * + * Discovered target names are appended to the targets array. + */ +function findEdgeTargetsInGraph(graph: graphlib.Graph, node: Node, inbound: boolean, targets: Edges): void { + let edges = inbound ? graph.inEdges(node.name) : graph.outEdges(node.name); + _.each(edges, (e) => { + let metaedge = graph.edge(e); + let targetList = metaedge.numRegularEdges ? targets.regular : targets.control; + targetList.push(metaedge as any); + }); +} +export interface HierarchyParams { + verifyTemplate: boolean; + seriesNodeMinSize: number; + // The initial map of explicit series group types. + seriesMap: SeriesGroupMap; + // This string is supplied to dagre as the 'rankdir' property for laying out + // the graph. TB, BT, LR, or RL. The default is 'BT' (bottom to top). + rankDirection: string; + // Whether to detect numeric patterns for series nodes using entire names of + // nodes. If false, only uses numeric suffixes to find patterns to collapse + // into series nodes. + useGeneralizedSeriesPatterns: boolean; +} +export const DefaultHierarchyParams: HierarchyParams = { + verifyTemplate: true, + seriesNodeMinSize: 5, + seriesMap: new Map(), + rankDirection: 'TB', + useGeneralizedSeriesPatterns: false, +}; +/** + * @param graph The raw graph. + * @param params Parameters used when building a hierarchy. + */ +export function build( + graph: tf_graph.SlimGraph, + params: HierarchyParams, + tracker: ProgressTracker | undefined, +): Promise { + const h = new Hierarchy(params); + const seriesNames: { + [name: string]: string; + } = {}; + return tf_graph_util + .runAsyncTask( + 'Adding nodes', + 30, + () => { + // Get all the possible device and XLA cluster names. + let deviceNames = {}; + let xlaClusterNames = {}; + _.each(graph.nodes, (node, nodeName) => { + if (node.device) { + deviceNames[node.device] = true; + } + if (node.xlaCluster) { + xlaClusterNames[node.xlaCluster] = true; + } + }); + h.devices = _.keys(deviceNames); + addNodesInVis(h, graph, ROOT_NAME); + }, + tracker, + tb_debug.GraphDebugEventId.HIERARCHY_ADD_NODES, + ) + .then(() => { + return tf_graph_util.runAsyncTask( + 'Detect series', + 30, + () => { + if (params.seriesNodeMinSize > 0) { + groupSeries( + h.root, + h, + seriesNames, + params.seriesNodeMinSize, + params.seriesMap, + params.useGeneralizedSeriesPatterns, + ); + } + }, + tracker, + tb_debug.GraphDebugEventId.HIERARCHY_DETECT_SERIES, + ); + }) + .then(() => { + return tf_graph_util.runAsyncTask( + 'Adding edges', + 40, + () => { + addEdgesInVis(h, graph, ROOT_NAME); + }, + tracker, + tb_debug.GraphDebugEventId.HIERARCHY_ADD_EDGES, + ); + }) + .then(() => { + return h; + }); +} +/** + * Updates hierarchy when the subgraph of a node is built. + * @param oldGraph + * @param slimGraph + */ +export function update(oldGraph: Hierarchy, slimGraph: tf_graph.SlimGraph, nodeName: string) { + let node = oldGraph.node(nodeName) as Metanode; + if (node) { + addNodesInVis(oldGraph, slimGraph, nodeName); + addEdgesInVis(oldGraph, slimGraph, nodeName); + } +} +/** + * Creates the metanodes in the hierarchical graph and assigns parent-child + * relationship between them in vis mode. + * @param h + * @param graph + * @param parentName + */ +function addNodesInVis(h: Hierarchy, graph: SlimGraph, parentName: string) { + const parentNode = h.node(parentName); + if (!(parentNode instanceof MetanodeImpl)) { + return; + } + const orderedNodes: Array<{ idx: number; name: string; node: any }> = []; + _.each([graph.nodes, graph.metaNodes], (nodes) => { + _.each(nodes, (node) => { + node.parentNode = parentNode; + orderedNodes.push({ + idx: node.nodeAttributes['_order'] ?? 0, + name: node.name, + node, + }); + h.setNode(node.name, node); + }); + }); + _.each( + orderedNodes.sort((a, b) => a.idx - b.idx), + (item) => { + parentNode.metagraph.setNode(item.name, item.node); + }, + ); +} +/** + * For each metanode in the hierarchical graph, this method adds: + * the edges in the metagraph. These are edges between nodes + * that share the same parent. + */ +function addEdges( + h: Hierarchy, + graph: SlimGraph, + seriesNames: { + [name: string]: string; + }, +) { + let nodeIndex = h.getNodeMap(); + // Ancestor paths for the source and destination nodes of an edge. These are + // reused for each edge rather than allocating new ones. It's about 10% faster + // than allocating new ones on each pass through the loop. + let sourcePath: string[] = []; + let destPath: string[] = []; + // Insert the ancestor path for a node into the provided array, including the + // node itself. Return the index of the last node inserted (always ROOT). + let getPath = (node: Node, path: string[]): number => { + let i = 0; + while (node) { + path[i++] = node.name; + node = node.parentNode; + } + return i - 1; + }; + _.each(graph.edges, (baseEdge) => { + // Get the hierarchical paths for the source and destination of the edge. + let sourceAncestorIndex = getPath(graph.nodes[baseEdge.v!], sourcePath); + let destAncestorIndex = getPath(graph.nodes[baseEdge.w!], destPath); + // If the hierarchical path cannot be found for either endpoint, then we + // cannot create the edge. This happens for example when a node has a + // control dependency on a summary node, which are embedded. + if (sourceAncestorIndex === -1 || destAncestorIndex === -1) { + return; + } + // Find the lowest shared ancestor between source and dest by looking for + // the highest nodes that differ between their ancestor paths. + while (sourcePath[sourceAncestorIndex] === destPath[destAncestorIndex]) { + sourceAncestorIndex--; + destAncestorIndex--; + if (sourceAncestorIndex < 0 || destAncestorIndex < 0) { + // This would only occur if the two nodes were the same (a cycle in the + // graph), or if one endpoint was a strict ancestor of the other. The + // latter shouldn't happen because we rename nodes which are both + // metanodes and op nodes. E.g. 'A/B' becomes 'A/B/(B)'. + throw Error('No difference found between ancestor paths.'); + } + } + let sharedAncestorNode = nodeIndex[sourcePath[sourceAncestorIndex + 1]]; + let sourceAncestorName = sourcePath[sourceAncestorIndex]; + let destAncestorName = destPath[destAncestorIndex]; + // Find or create the Metaedge which should contain this BaseEdge inside + // the shared ancestor. + let metaedge = sharedAncestorNode.metagraph.edge(sourceAncestorName, destAncestorName); + if (!metaedge) { + metaedge = createMetaedge(sourceAncestorName, destAncestorName) as any; + sharedAncestorNode.metagraph.setEdge(sourceAncestorName, destAncestorName, metaedge); + } + if (!sharedAncestorNode.hasNonControlEdges) { + sharedAncestorNode.hasNonControlEdges = true; + } + metaedge.addBaseEdge(baseEdge, h); + }); +} +/** + * Create edges in the metanode. + * @param h + * @param graph + * @param nodeName + */ +function addEdgesInVis(h: Hierarchy, graph: SlimGraph, nodeName: string) { + const metaNode = h.node(nodeName); + if (!(metaNode instanceof MetanodeImpl)) { + return; + } + _.each(graph.edges, (baseEdge) => { + if (!(baseEdge.w?.includes(tf_graph.NAMESPACE_DELIM) || baseEdge.v?.includes(tf_graph.NAMESPACE_DELIM))) { + if ( + !(baseEdge.v! in graph.nodes || baseEdge.v! in graph.metaNodes) || + !(baseEdge.w! in graph.nodes || baseEdge.w! in graph.metaNodes) + ) { + return; + } + } + const srcName = baseEdge.v!; + const dstName = baseEdge.w!; + let metaedge = metaNode.metagraph.edge(srcName, dstName, baseEdge.attr?.id); + if (!metaedge) { + metaedge = createMetaedge(srcName, dstName) as any; + metaNode.metagraph.setEdge(srcName, dstName, metaedge, baseEdge.attr?.id); + } + metaedge.addBaseEdge(baseEdge, h, true); + }); +} +/** + * Using the hierarchy template information, detect series in the provided + * metanode. For each detected series, create a new SeriesNode + * and remove series members from the metanode's metagraph and move them to + * the new series node's metagraph. + * + * @param metanode + * @param hierarchy + * @param seriesNames Map of node names to their series they are contained in. + * This should be provided empty and is populated by this method. + * @param threshold If the series has this many nodes or more, then group them + * into a series. + * @param map Map of series names to their series grouping type, if one has + * been set. + * @param useGeneralizedSeriesPatterns Whether to use find patterns for series + * nodes using any parts of names of nodes. If false, only uses patterns + * discovered within numeric suffixes of nodes names. + * @return A dictionary from node name to series node name that contains the + * node. + */ +function groupSeries( + metanode: Metanode, + hierarchy: Hierarchy, + seriesNames: { + [name: string]: string; + }, + threshold: number, + seriesMap: SeriesGroupMap, + useGeneralizedSeriesPatterns: boolean, +) { + let metagraph = metanode.metagraph; + _.each(metagraph.nodes(), (n) => { + let child = metagraph.node(n); + if (child.type === (tf_graph.NodeType.META || tf_graph.NodeType.API_LIST)) { + groupSeries( + child as unknown as Metanode, + hierarchy, + seriesNames, + threshold, + seriesMap, + useGeneralizedSeriesPatterns, + ); + } + }); + let clusters = clusterNodes(metagraph); + const detectSeriesMethod = useGeneralizedSeriesPatterns + ? detectSeriesAnywhereInNodeName + : detectSeriesUsingNumericSuffixes; + let seriesDict = detectSeriesMethod(clusters, metagraph, hierarchy.graphOptions); + // Add each series node to the graph and add its grouped children to its own + // metagraph. + _.each(seriesDict, function (seriesNode: SeriesNode, seriesName: string) { + let nodeMemberNames = seriesNode.metagraph.nodes(); + _.each(nodeMemberNames, (n) => { + let child = metagraph.node(n) as any; + if (!child.owningSeries) { + child.owningSeries = seriesName; + } + }); + // If the series contains less than the threshold number of nodes, then set + // this series to be shown ungrouped in the map. + if ( + nodeMemberNames.length < threshold && + hierarchy.getSeriesGroupType(seriesNode.name) === tf_graph.SeriesGroupingType.GROUP + ) { + hierarchy.setSeriesGroupType(seriesNode.name, tf_graph.SeriesGroupingType.UNGROUP); + } + // If the series is in the map as ungrouped then do not group the series. + if (hierarchy.getSeriesGroupType(seriesNode.name) === tf_graph.SeriesGroupingType.UNGROUP) { + return; + } + hierarchy.setNode(seriesName, seriesNode); // add to the index + metagraph.setNode(seriesName, seriesNode); + _.each(nodeMemberNames, (n) => { + let child = metagraph.node(n) as any; + seriesNode.metagraph.setNode(n, child); + seriesNode.parentNode = child.parentNode; + seriesNode.cardinality++; + if (child.device != null) { + seriesNode.deviceHistogram[child.device] = (seriesNode.deviceHistogram[child.device] || 0) + 1; + } + if (child.xlaCluster != null) { + seriesNode.xlaClusterHistogram[child.xlaCluster] = (seriesNode.xlaClusterHistogram[child.xlaCluster] || 0) + 1; + } + // Increment parents appropriate compatibility count + if (child.compatible) { + seriesNode.compatibilityHistogram.compatible = (seriesNode.compatibilityHistogram.compatible || 0) + 1; + } else { + seriesNode.compatibilityHistogram.incompatible = (seriesNode.compatibilityHistogram.incompatible || 0) + 1; + } + // Increment capability counts for in and out embeddings + _.each(child.inEmbeddings, (inNode) => { + if (inNode.compatible) { + seriesNode.compatibilityHistogram.compatible = (seriesNode.compatibilityHistogram.compatible || 0) + 1; + } else { + seriesNode.compatibilityHistogram.incompatible = (seriesNode.compatibilityHistogram.incompatible || 0) + 1; + } + }); + _.each(child.outEmbeddings, (outNode) => { + if (outNode.compatible) { + seriesNode.compatibilityHistogram.compatible = (seriesNode.compatibilityHistogram.compatible || 0) + 1; + } else { + seriesNode.compatibilityHistogram.incompatible = (seriesNode.compatibilityHistogram.incompatible || 0) + 1; + } + }); + child.parentNode = seriesNode; + seriesNames[n] = seriesName; + // Remove now-grouped node from its original parent's metagraph. + metagraph.removeNode(n); + }); + }); +} +/** + * Cluster op-nodes with similar op. This examines only the direct children of + * the metagraph, does not recursively check descendants. + * @return A map from op to a list of node names. + */ +function clusterNodes(metagraph: graphlib.Graph): { + [clusterId: string]: string[]; +} { + let result: { + [clusterId: string]: string[]; + } = {}; + return _.reduce( + metagraph.nodes(), + ( + clusters: { + [clusterId: string]: string[]; + }, + n: string, + ) => { + let child = metagraph.node(n); + if (child.type === (NodeType.META || NodeType.API_LIST)) { + // skip metanodes + return clusters; + } + let template = (child as any).op; + if (template) { + clusters[template] = clusters[template] || []; + clusters[template].push(child.name); + } + return clusters; + }, + result, + ); +} +/** + * For each cluster of op-nodes based op type, try to detect groupings. + * Infer series name using by trying to find pattern '' towards the end + * of node names. + * + * @param clusters Dictionary output from clusterNodes(). + * @param metagraph + * @return A dictionary from series name => seriesNode + */ +function detectSeriesUsingNumericSuffixes( + clusters: { + [clusterId: string]: string[]; + }, + metagraph: graphlib.Graph, + graphOptions: tf_graph.LabeledGraphOptions, +): { + [seriesName: string]: SeriesNode; +} { + let seriesDict: { + [seriesName: string]: SeriesNode; + } = {}; + _.each(clusters, function (members, clusterId: string) { + if (members.length <= 1) { + return; + } // isolated clusters can't make series + /** @type {Object} A dictionary mapping seriesName to seriesInfoArray, + * which is an array that contains objects with name, id, prefix, suffix, + * and parent properties. + */ + let candidatesDict: { + [seriesName: string]: SeriesNode[]; + } = {}; + // Group all nodes that have the same name, with the exception of a + // number at the end of the name after an underscore, which is allowed to + // vary. + _.each(members, function (name: string) { + const isGroup = name.charAt(name.length - 1) === '*'; + const namepath = name.split('/'); + const leaf = namepath[namepath.length - 1]; + const parent = namepath.slice(0, namepath.length - 1).join('/'); + const matches = leaf.match(/^(\D*)(\d+)$/); + let prefix; + let id; + let suffix = ''; + if (matches) { + // if found '' in the name, assign id. + prefix = matches[1]; // the front non-numeric characters + id = matches[2]; // the digits + } else { + // for node without '_', make them zero-th items. + prefix = isGroup ? leaf.substr(0, leaf.length - 1) : leaf; + id = 0; + suffix = isGroup ? '*' : ''; + } + const seriesName = getSeriesNodeName(prefix, suffix, parent); + candidatesDict[seriesName] = candidatesDict[seriesName] || []; + const seriesNode = createSeriesNode(prefix, suffix, parent, +id, name, graphOptions); + candidatesDict[seriesName].push(seriesNode); + }); + // In each group of nodes, group nodes in bunches that have monotonically + // increasing numbers in their names. Each of these bunches is a series. + _.each(candidatesDict, function (seriesInfoArray: SeriesNode[], seriesName) { + if (seriesInfoArray.length < 2) { + return; + } + seriesInfoArray.sort(function (a, b) { + return +a.clusterId - +b.clusterId; + }); + // Loop through the nodes sorted by its detected series number, grouping + // all nodes with monotonically-increasing series numbers. + let seriesNodes = [seriesInfoArray[0]]; + for (let index = 1; index < seriesInfoArray.length; index++) { + let nextNode = seriesInfoArray[index]; + if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId + 1) { + seriesNodes.push(nextNode); + continue; + } + addSeriesToDict(seriesNodes, seriesDict, +clusterId, metagraph, graphOptions); + seriesNodes = [nextNode]; + } + addSeriesToDict(seriesNodes, seriesDict, +clusterId, metagraph, graphOptions); + }); + }); + return seriesDict; +} +/** + * For each cluster of op-nodes based op type, try to detect groupings. + * Infer series name using by trying to find a pattern of numbers + * anywhere within node names. + * + * @param clusters Dictionary output from clusterNodes(). + * @param metagraph + * @return A dictionary from series name => seriesNode + */ +function detectSeriesAnywhereInNodeName( + clusters: { + [clusterId: string]: string[]; + }, + metagraph: graphlib.Graph, + graphOptions: tf_graph.LabeledGraphOptions, +): { + [seriesName: string]: SeriesNode; +} { + let seriesDict: { + [seriesName: string]: SeriesNode; + } = {}; + _.each(clusters, function (members, clusterId: string) { + if (members.length <= 1) { + return; + } // isolated clusters can't make series + /** + * @type {Object} A dictionary mapping a series name to a SeriesNode. + */ + let forwardDict: { + [seriesName: string]: SeriesNode; + } = {}; + /** + * @type {Object} A dictionary mapping member name to an array of series + * names this member could potentially be grouped under and the + * corresponding ids. + */ + let reverseDict: { + [seriesName: string]: any[]; + } = {}; + // Group all nodes that have the same name, with the exception of a + // number at the end of the name after an underscore, which is allowed to + // vary. + _.each(members, function (name: string) { + let isGroup = name.charAt(name.length - 1) === '*'; + let namepath = name.split('/'); + let leaf = namepath[namepath.length - 1]; + let parent = namepath.slice(0, namepath.length - 1).join('/'); + const numRegex = /(\d+)/g; + let matches = []; + let matchResult; + let prefix; + let id; + let suffix; + let seriesName; + let matched = 0; + // Scan over the entire leaf name and match any possible numbers, + // and put the results into corresponding dictionaries. + while ((matchResult = numRegex.exec(leaf))) { + ++matched; + prefix = leaf.slice(0, matchResult.index); + id = matchResult[0]; + suffix = leaf.slice(matchResult.index + matchResult[0].length); + seriesName = getSeriesNodeName(prefix, suffix, parent); + forwardDict[seriesName] = forwardDict[seriesName]; + if (!forwardDict[seriesName]) { + forwardDict[seriesName] = createSeriesNode(prefix, suffix, parent, +id, name, graphOptions); + } + forwardDict[seriesName].ids.push(id); + reverseDict[name] = reverseDict[name] || []; + reverseDict[name].push([seriesName, id]); + } + if (matched < 1) { + prefix = isGroup ? leaf.substr(0, leaf.length - 1) : leaf; + id = 0; + suffix = isGroup ? '*' : ''; + seriesName = getSeriesNodeName(prefix, suffix, parent); + forwardDict[seriesName] = forwardDict[seriesName]; + if (!forwardDict[seriesName]) { + forwardDict[seriesName] = createSeriesNode(prefix, suffix, parent, +id, name, graphOptions); + } + forwardDict[seriesName].ids.push(id); + reverseDict[name] = reverseDict[name] || []; + reverseDict[name].push([seriesName, id]); + } + }); + /** @type {Object} A dictionary mapping seriesName to seriesInfoArray, + * which is an array that contains objects with name, id, prefix, suffix, + * and parent properties. + */ + var candidatesDict: { + [seriesName: string]: SeriesNode[]; + } = {}; + // For each of the member, put it into the maximum possible series, + // and create candidatesDict accordingly. + _.each(reverseDict, function (seriesNameIdArray, name) { + seriesNameIdArray.sort(function (a, b) { + return forwardDict[b[0]].ids.length - forwardDict[a[0]].ids.length; + }); + var seriesName = seriesNameIdArray[0][0]; + var id = seriesNameIdArray[0][1]; + candidatesDict[seriesName] = candidatesDict[seriesName] || []; + const namepath = name.split('/'); + const leaf = namepath[namepath.length - 1]; + const parent = namepath.slice(0, namepath.length - 1).join('/'); + var seriesNode = createSeriesNode( + forwardDict[seriesName].prefix, + forwardDict[seriesName].suffix, + parent, + +id, + name, + graphOptions, + ); + candidatesDict[seriesName].push(seriesNode); + }); + // In each group of nodes, group nodes in bunches that have monotonically + // increasing numbers in their names. Each of these bunches is a series. + _.each(candidatesDict, function (seriesInfoArray: SeriesNode[], seriesName) { + if (seriesInfoArray.length < 2) { + return; + } + seriesInfoArray.sort(function (a, b) { + return +a.clusterId - +b.clusterId; + }); + // Loop through the nodes sorted by its detected series number, grouping + // all nodes with monotonically-increasing series numbers. + let seriesNodes = [seriesInfoArray[0]]; + for (let index = 1; index < seriesInfoArray.length; index++) { + let nextNode = seriesInfoArray[index]; + if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId + 1) { + seriesNodes.push(nextNode); + continue; + } + addSeriesToDict(seriesNodes, seriesDict, +clusterId, metagraph, graphOptions); + seriesNodes = [nextNode]; + } + addSeriesToDict(seriesNodes, seriesDict, +clusterId, metagraph, graphOptions); + }); + }); + return seriesDict; +} +/** + * Add a series to the provided dictionary mapping series names to series. + * + * @param seriesNodes the nodes in the series. Contains + * name, id, prefix, suffix and parent properties of the node. + * @param seriesDict the dictionary of series + * @param clusterId ID of the template of the nodes of the series + * @param metagraph + * @param graphOptions + */ +function addSeriesToDict( + seriesNodes: SeriesNode[], + seriesDict: { + [seriesName: string]: SeriesNode; + }, + clusterId: number, + metagraph: graphlib.Graph, + graphOptions: tf_graph.LabeledGraphOptions, +) { + if (seriesNodes.length > 1) { + let curSeriesName = getSeriesNodeName( + seriesNodes[0].prefix, + seriesNodes[0].suffix, + seriesNodes[0].parent, + seriesNodes[0].clusterId, + seriesNodes[seriesNodes.length - 1].clusterId, + ); + let curSeriesNode = createSeriesNode( + seriesNodes[0].prefix, + seriesNodes[0].suffix, + seriesNodes[0].parent, + clusterId, + curSeriesName, + graphOptions, + ); + _.each(seriesNodes, function (node) { + curSeriesNode.ids.push(node.clusterId); + curSeriesNode.metagraph.setNode(node.name, metagraph.node(node.name)); + }); + seriesDict[curSeriesName] = curSeriesNode; + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/layout.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/layout.ts new file mode 100644 index 0000000000000000000000000000000000000000..d1cf589fddc29beaf464b8bcb188dde46b1bce1e --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/layout.ts @@ -0,0 +1,725 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as d3 from 'd3'; +import * as dagre from 'dagre'; +import { graphlib } from 'dagre'; +import * as _ from 'lodash'; +import { NodeType } from './graph'; +import * as render from './render'; + +export const PARAMS = { + animation: { + /** Default duration for graph animations in ms. */ + duration: 250, + }, + graph: { + /** Graph parameter for metanode. */ + meta: { + /** + * Dagre's nodesep param - number of pixels that + * separate nodes horizontally in the layout. + * + * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout + */ + nodeSep: 5, + /** + * Dagre's ranksep param - number of pixels + * between each rank in the layout. + * + * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout + */ + rankSep: 25, + /** + * Dagre's edgesep param - number of pixels that separate + * edges horizontally in the layout. + */ + edgeSep: 5, + }, + /** Graph parameter for metanode. */ + series: { + /** + * Dagre's nodesep param - number of pixels that + * separate nodes horizontally in the layout. + * + * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout + */ + nodeSep: 5, + /** + * Dagre's ranksep param - number of pixels + * between each rank in the layout. + * + * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout + */ + rankSep: 25, + /** + * Dagre's edgesep param - number of pixels that separate + * edges horizontally in the layout. + */ + edgeSep: 5, + }, + /** + * Padding is used to correctly position the graph SVG inside of its parent + * element. The padding amounts are applied using an SVG transform of X and + * Y coordinates. + */ + padding: { paddingTop: 40, paddingLeft: 20 }, + }, + subscene: { + meta: { + paddingTop: 10, + paddingBottom: 10, + paddingLeft: 30, + paddingRight: 30, + /** + * Used to leave room for the label on top of the highest node in + * the core graph. + */ + labelHeight: 20, + /** X-space between each extracted node and the core graph. */ + extractXOffset: 15, + /** Y-space between each extracted node. */ + extractYOffset: 20, + }, + series: { + paddingTop: 10, + paddingBottom: 10, + paddingLeft: 10, + paddingRight: 10, + labelHeight: 10, + }, + }, + nodeSize: { + /** Size of meta nodes. */ + meta: { + radius: 5, + width: 60, + maxLabelWidth: 200, + /** A scale for the node's height based on number of nodes inside */ + // Hack - set this as an any type to avoid issues in exporting a type + // from an external module. + height: (d3 as any).scaleLinear().domain([1, 200]).range([15, 60]).clamp(true), + /** The radius of the circle denoting the expand button. */ + expandButtonRadius: 3, + }, + api_list: { + radius: 5, + width: 60, + maxLabelWidth: 200, + /** A scale for the node's height based on number of nodes inside */ + // Hack - set this as an any type to avoid issues in exporting a type + // from an external module. + height: (d3 as any).scaleLinear().domain([1, 200]).range([15, 60]).clamp(true), + /** The radius of the circle denoting the expand button. */ + expandButtonRadius: 3, + }, + /** Size of op nodes. */ + op: { + width: 30, + height: 12, + radius: 6, + labelOffset: -12, + maxLabelWidth: 60, + }, + /** Size of series nodes. */ + series: { + expanded: { + // For expanded series nodes, width and height will be + // computed to account for the subscene. + radius: 10, + labelOffset: 0, + }, + vertical: { + // When unexpanded, series whose underlying metagraphs contain + // one or more non-control edges will show as a vertical stack + // of ellipses. + width: 16, + height: 13, + labelOffset: -13, + }, + horizontal: { + // When unexpanded, series whose underlying metagraphs contain + // no non-control edges will show as a horizontal stack of + // ellipses. + width: 24, + height: 8, + radius: 10, + labelOffset: -10, + }, + }, + /** Size of bridge nodes. */ + bridge: { + // NOTE: bridge nodes will normally be invisible, but they must + // take up some space so that the layout step leaves room for + // their edges. + width: 20, + height: 20, + radius: 2, + labelOffset: 0, + }, + }, + shortcutSize: { + /** Size of shortcuts for op nodes */ + op: { width: 10, height: 4 }, + /** Size of shortcuts for meta nodes */ + meta: { width: 12, height: 4, radius: 1 }, + /** Size of shortcuts for series nodes */ + series: { + width: 14, + height: 4, + }, + api_list: { width: 12, height: 4, radius: 1 }, + }, + annotations: { + /** Maximum possible width of the bounding box for in annotations */ + inboxWidth: 50, + /** Maximum possible width of the bounding box for out annotations */ + outboxWidth: 50, + /** X-space between the shape and each annotation-node. */ + xOffset: 10, + /** Y-space between each annotation-node. */ + yOffset: 3, + /** X-space between each annotation-node and its label. */ + labelOffset: 2, + /** Defines the max width for annotation label */ + maxLabelWidth: 40, + }, + constant: { size: { width: 4, height: 4 } }, + series: { + /** Maximum number of repeated item for unexpanded series node. */ + maxStackCount: 3, + /** + * Positioning offset ratio for collapsed stack + * of parallel series (series without edges between its members). + */ + parallelStackOffsetRatio: 0.2, + /** + * Positioning offset ratio for collapsed stack + * of tower series (series with edges between its members). + */ + towerStackOffsetRatio: 0.5, + }, + minimap: { + /** The maximum width/height the minimap can have. */ + size: 150, + }, +}; +/** + * The minimum width we confer upon the auxiliary nodes section if functions + * also appear. Without enforcing this minimum, metanodes in the function + * library section could jut into the auxiliary nodes section because the + * title "Auxiliary Nodes" is longer than the width of the auxiliary nodes + * section itself. + */ +export const MIN_AUX_WIDTH = 140; +/** + * Keep this number as the same as 'maxMetanodeLabelLength' in 'tf-graph-scene' + */ +export const MAX_TEXT_LENGTH = 50; +/** + * 6 pixels per character. + */ +export const CHARACTER_WIDTH = 6; +/** Calculate layout for a scene of a group node. */ +export function layoutScene(renderNodeInfo: render.RenderGroupNodeInfo): void { + // Update layout, size, and annotations of its children nodes and edges. + if (renderNodeInfo.node.isGroupNode) { + layoutChildren(renderNodeInfo); + } + // Update position of its children nodes and edges + if (renderNodeInfo.node.type === NodeType.META) { + layoutMetanode(renderNodeInfo); + } else if (renderNodeInfo.node.type === NodeType.SERIES) { + layoutSeriesNode(renderNodeInfo); + } else if (renderNodeInfo.node.type === NodeType.API_LIST) { + layoutMetanode(renderNodeInfo); + } +} +/** + * Updates the total width of an unexpanded node which includes the size of its + * in and out annotations. + */ +function updateTotalWidthOfNode(renderInfo: render.RenderNodeInfo): void { + renderInfo.inboxWidth = renderInfo.inAnnotations.list.length > 0 ? PARAMS.annotations.inboxWidth : 0; + renderInfo.outboxWidth = renderInfo.outAnnotations.list.length > 0 ? PARAMS.annotations.outboxWidth : 0; + // Assign the width of the core box (the main shape of the node). + renderInfo.coreBox.width = renderInfo.width; + renderInfo.coreBox.height = renderInfo.height; + // TODO: Account for font width rather than using a magic number. + let labelLength = renderInfo.displayName.length; + // Compute the total width of the node. + const maxLabelWidth = + renderInfo.node.type === NodeType.OP ? PARAMS.nodeSize.op.maxLabelWidth : PARAMS.nodeSize.meta.maxLabelWidth; + renderInfo.width = Math.max( + renderInfo.coreBox.width + renderInfo.inboxWidth + renderInfo.outboxWidth, + Math.min(labelLength * CHARACTER_WIDTH, maxLabelWidth), + ); +} +/** + * Update layout, size, and annotations of its children nodes and edges. + */ +function layoutChildren(renderNodeInfo: render.RenderGroupNodeInfo): void { + let children = renderNodeInfo.coreGraph + .nodes() + .map((n) => { + return renderNodeInfo.coreGraph.node(n); + }) + .concat( + renderNodeInfo.isolatedInExtract, + renderNodeInfo.isolatedOutExtract, + renderNodeInfo.libraryFunctionsExtract, + ); + _.each(children, (childNodeInfo) => { + // Set size of each child + switch (childNodeInfo.node.type) { + case NodeType.OP: + _.extend(childNodeInfo, PARAMS.nodeSize.op); + break; + case NodeType.BRIDGE: + _.extend(childNodeInfo, PARAMS.nodeSize.bridge); + break; + case NodeType.META: + if (!childNodeInfo.expanded) { + // Set fixed width and scalable height based on cardinality + _.extend(childNodeInfo, PARAMS.nodeSize.meta); + childNodeInfo.height = PARAMS.nodeSize.meta.height(childNodeInfo.node.cardinality); + childNodeInfo.width = Math.max( + childNodeInfo.width, + Math.min(childNodeInfo.displayName.length, MAX_TEXT_LENGTH) * CHARACTER_WIDTH, + ); + } else { + let childGroupNodeInfo = childNodeInfo; + layoutScene(childGroupNodeInfo); // Recursively layout its subscene. + } + break; + case NodeType.SERIES: + if (childNodeInfo.expanded) { + _.extend(childNodeInfo, PARAMS.nodeSize.series.expanded); + let childGroupNodeInfo = childNodeInfo; + layoutScene(childGroupNodeInfo); // Recursively layout its subscene. + } else { + let childGroupNodeInfo = childNodeInfo; + let seriesParams = childGroupNodeInfo.node.hasNonControlEdges + ? PARAMS.nodeSize.series.vertical + : PARAMS.nodeSize.series.horizontal; + _.extend(childNodeInfo, seriesParams); + } + break; + case NodeType.API_LIST: + if (!childNodeInfo.expanded) { + // Set fixed width and scalable height based on cardinality + _.extend(childNodeInfo, PARAMS.nodeSize.meta); + childNodeInfo.height = PARAMS.nodeSize.meta.height(childNodeInfo.node.cardinality); + childNodeInfo.width = Math.max( + childNodeInfo.width, + Math.min(childNodeInfo.displayName.length, MAX_TEXT_LENGTH) * CHARACTER_WIDTH, + ); + } else { + let childGroupNodeInfo = childNodeInfo; + layoutScene(childGroupNodeInfo); // Recursively layout its subscene. + } + break; + default: + throw Error('Unrecognized node type: ' + childNodeInfo.node.type); + } + // Compute total width of un-expanded nodes. Width of expanded nodes + // has already been computed. + if (!childNodeInfo.expanded) { + updateTotalWidthOfNode(childNodeInfo); + } + // Layout each child's annotations + layoutAnnotation(childNodeInfo); + }); +} +/** + * Calculate layout for a graph using dagre + * @param graph the graph to be laid out + * @param params layout parameters + * @return width and height of the core graph + */ +function dagreLayout(graph: graphlib.Graph, params): { height: number; width: number } { + _.extend(graph.graph(), { + nodesep: params.nodeSep, + ranksep: params.rankSep, + edgesep: params.edgeSep, + }); + let bridgeNodeNames: any[] = []; + let nonBridgeNodeNames: any[] = []; + // Split out nodes into bridge and non-bridge nodes, and calculate the total + // width we should use for bridge nodes. + _.each(graph.nodes(), (nodeName) => { + let nodeInfo = graph.node(nodeName); + if (nodeInfo.node.type === NodeType.BRIDGE) { + bridgeNodeNames.push(nodeName); + } else { + nonBridgeNodeNames.push(nodeName); + } + }); + // If there are no non-bridge nodes, then the graph has zero size. + if (!nonBridgeNodeNames.length) { + return { + width: 0, + height: 0, + }; + } + dagre.layout(graph); + // Calculate the true bounding box of the graph by iterating over nodes and + // edges rather than accepting dagre's word for it. In particular, we should + // ignore the extra-wide bridge nodes and bridge edges, and allow for + // annotation boxes and labels. + let minX = Infinity; + let minY = Infinity; + let maxX = -Infinity; + let maxY = -Infinity; + _.each(nonBridgeNodeNames, (nodeName) => { + let nodeInfo = graph.node(nodeName); + let w = 0.5 * nodeInfo.width; + let x1 = nodeInfo.x - w; + let x2 = nodeInfo.x + w; + minX = x1 < minX ? x1 : minX; + maxX = x2 > maxX ? x2 : maxX; + // TODO: Account for the height of labels above op nodes here. + let h = 0.5 * nodeInfo.height; + let y1 = nodeInfo.y - h; + let y2 = nodeInfo.y + h; + minY = y1 < minY ? y1 : minY; + maxY = y2 > maxY ? y2 : maxY; + }); + + _.each(graph.nodes(), (nodeName) => { + let nodeInfo = graph.node(nodeName); + nodeInfo.x -= minX; + nodeInfo.y -= minY; + }); + return { + width: maxX - minX, + height: maxY - minY, + }; +} +/** Layout a metanode. Only called for an expanded node. */ +function layoutMetanode(renderNodeInfo: render.RenderGroupNodeInfo): void { + // First, copy params specific to meta nodes onto this render info object. + let params = PARAMS.subscene.meta; + _.extend(renderNodeInfo, params); + // Invoke dagre.layout() on the core graph and record the bounding box + // dimensions. + _.extend(renderNodeInfo.coreBox, dagreLayout(renderNodeInfo.coreGraph, PARAMS.graph.meta)); + // Calculate the position of nodes in isolatedInExtract relative to the + // top-left corner of inExtractBox (the bounding box for all inExtract nodes) + // and calculate the size of the inExtractBox. + let maxInExtractWidth = renderNodeInfo.isolatedInExtract.length + ? _.maxBy(renderNodeInfo.isolatedInExtract, (renderNode) => renderNode.width)!.width + : null; + renderNodeInfo.inExtractBox.width = maxInExtractWidth != null ? maxInExtractWidth : 0; + renderNodeInfo.inExtractBox.height = _.reduce( + renderNodeInfo.isolatedInExtract, + (height, child, i) => { + let yOffset = i > 0 ? params.extractYOffset : 0; + // use width/height here to avoid overlaps between extracts + child.x = 0; + child.y = height + yOffset + child.height / 2; + return height + yOffset + child.height; + }, + 0, + ); + // Calculate the position of nodes in isolatedOutExtract relative to the + // top-left corner of outExtractBox (the bounding box for all outExtract + // nodes) and calculate the size of the outExtractBox. + let maxOutExtractWidth = renderNodeInfo.isolatedOutExtract.length + ? _.maxBy(renderNodeInfo.isolatedOutExtract, (renderNode) => renderNode.width)!.width + : null; + renderNodeInfo.outExtractBox.width = maxOutExtractWidth != null ? maxOutExtractWidth : 0; + renderNodeInfo.outExtractBox.height = _.reduce( + renderNodeInfo.isolatedOutExtract, + (height, child, i) => { + let yOffset = i > 0 ? params.extractYOffset : 0; + // use width/height here to avoid overlaps between extracts + child.x = 0; + child.y = height + yOffset + child.height / 2; + return height + yOffset + child.height; + }, + 0, + ); + // Calculate the position of nodes in libraryFunctionsExtract relative to the + // top-left corner of libraryFunctionsBox (the bounding box for all library + // function nodes) and calculate the size of the libraryFunctionsBox. + let maxLibraryFunctionsWidth = renderNodeInfo.libraryFunctionsExtract.length + ? _.maxBy(renderNodeInfo.libraryFunctionsExtract, (renderNode) => renderNode.width)!.width + : null; + renderNodeInfo.libraryFunctionsBox.width = maxLibraryFunctionsWidth != null ? maxLibraryFunctionsWidth : 0; + renderNodeInfo.libraryFunctionsBox.height = _.reduce( + renderNodeInfo.libraryFunctionsExtract, + (height, child, i) => { + let yOffset = i > 0 ? params.extractYOffset : 0; + // use width/height here to avoid overlaps between extracts + child.x = 0; + child.y = height + yOffset + child.height / 2; + return height + yOffset + child.height; + }, + 0, + ); + // Compute the total padding between the core graph, in-extract and + // out-extract boxes. + let numParts = 0; + if (renderNodeInfo.isolatedInExtract.length > 0) { + numParts++; + } + if (renderNodeInfo.isolatedOutExtract.length > 0) { + numParts++; + } + if (renderNodeInfo.libraryFunctionsExtract.length > 0) { + numParts++; + } + if (renderNodeInfo.coreGraph.nodeCount() > 0) { + numParts++; + } + let offset = PARAMS.subscene.meta.extractXOffset; + let padding = numParts <= 1 ? 0 : numParts * offset; + // Add the in-extract and out-extract width to the core box width. Do not let + // the auxiliary width be too small, lest it be smaller than the title. + renderNodeInfo.coreBox.width += padding + renderNodeInfo.libraryFunctionsBox.width + padding; + renderNodeInfo.coreBox.height = + params.labelHeight + + Math.max( + renderNodeInfo.inExtractBox.height, + renderNodeInfo.coreBox.height, + renderNodeInfo.libraryFunctionsBox.height, + renderNodeInfo.outExtractBox.height, + ); + // Determine the whole metanode's width (from left to right). + renderNodeInfo.width = + Math.max(renderNodeInfo.displayName.length * CHARACTER_WIDTH, renderNodeInfo.coreBox.width) + + params.paddingLeft + + params.paddingRight; + // Determine the whole metanode's height (from top to bottom). + renderNodeInfo.height = renderNodeInfo.paddingTop + renderNodeInfo.coreBox.height + renderNodeInfo.paddingBottom; +} +/** + * Calculate layout for series node's core graph. Only called for an expanded + * series. + */ +function layoutSeriesNode(node: render.RenderGroupNodeInfo): void { + let graph = node.coreGraph; + let params = PARAMS.subscene.series; + _.extend(node, params); + // Layout the core. + _.extend(node.coreBox, dagreLayout(node.coreGraph, PARAMS.graph.series)); + _.each(graph.nodes(), (nodeName) => { + graph.node(nodeName).excluded = false; + }); + // Series do not have in/outExtractBox so no need to include them here. + node.width = node.coreBox.width + params.paddingLeft + params.paddingRight; + node.height = node.coreBox.height + params.paddingTop + params.paddingBottom; +} +/** + * Calculate layout for annotations of a given node. + * This will modify positions of the given node and its annotations. + * + * @see tf.graph.render.Node and tf.graph.render.Annotation + * for description of each property of each render node. + * + */ +function layoutAnnotation(renderNodeInfo: render.RenderNodeInfo): void { + // If the render node is an expanded metanode, then its annotations will not + // be visible and we should skip the annotation calculations. + if (renderNodeInfo.expanded) { + return; + } + let inAnnotations = renderNodeInfo.inAnnotations.list; + let outAnnotations = renderNodeInfo.outAnnotations.list; + // Calculate size for in-annotations + _.each(inAnnotations, (a) => sizeAnnotation(a)); + // Calculate size for out-annotations + _.each(outAnnotations, (a) => sizeAnnotation(a)); + let params = PARAMS.annotations; + // Calculate annotation node position (a.dx, a.dy) + // and total height for in-annotations + // After this chunk of code: + // inboxHeight = sum of annotation heights+ (annotation.length - 1 * yOffset) + let inboxHeight = _.reduce( + inAnnotations, + (height, a, i) => { + let yOffset = i > 0 ? params.yOffset : 0; + a.dx = -(renderNodeInfo.coreBox.width + a.width) / 2 - params.xOffset; + a.dy = height + yOffset + a.height / 2; + return height + yOffset + a.height; + }, + 0, + ); + _.each(inAnnotations, (a) => { + a.dy -= inboxHeight / 2; + a.labelOffset = params.labelOffset; + }); + // Calculate annotation node position (a.dx, a.dy) + // and total height for out-annotations + // After this chunk of code: + // outboxHeight = sum of annotation heights + + // (annotation.length - 1 * yOffset) + let outboxHeight = _.reduce( + outAnnotations, + (height, a, i) => { + let yOffset = i > 0 ? params.yOffset : 0; + a.dx = (renderNodeInfo.coreBox.width + a.width) / 2 + params.xOffset; + a.dy = height + yOffset + a.height / 2; + return height + yOffset + a.height; + }, + 0, + ); + _.each(outAnnotations, (a) => { + // adjust by (half of ) the total height + // so dy is relative to the host node's center. + a.dy -= outboxHeight / 2; + a.labelOffset = params.labelOffset; + }); + // Creating scales for touch point between the in-annotation edges + // and their hosts. + let inTouchHeight = Math.min(renderNodeInfo.height / 2 - renderNodeInfo.radius, inboxHeight / 2); + inTouchHeight = inTouchHeight < 0 ? 0 : inTouchHeight; + let inY = d3 + .scaleLinear() + .domain([0, inAnnotations.length - 1]) + .range([-inTouchHeight, inTouchHeight]); + // Calculate annotation edge position + _.each(inAnnotations, (a, i) => { + a.points = [ + // The annotation node end + { + dx: a.dx + a.width / 2, + dy: a.dy, + }, + // The host node end + { + dx: -renderNodeInfo.coreBox.width / 2, + // only use scale if there are more than one, + // otherwise center it vertically + dy: inAnnotations.length > 1 ? inY(i) : 0, + }, + ]; + }); + // Creating scales for touch point between the out-annotation edges + // and their hosts. + let outTouchHeight = Math.min(renderNodeInfo.height / 2 - renderNodeInfo.radius, outboxHeight / 2); + outTouchHeight = outTouchHeight < 0 ? 0 : outTouchHeight; + let outY = d3 + .scaleLinear() + .domain([0, outAnnotations.length - 1]) + .range([-outTouchHeight, outTouchHeight]); + _.each(outAnnotations, (a, i) => { + // Add point from the border of the annotation node + a.points = [ + // The host node end + { + dx: renderNodeInfo.coreBox.width / 2, + // only use scale if there are more than one, + // otherwise center it vertically + dy: outAnnotations.length > 1 ? outY(i) : 0, + }, + // The annotation node end + { + dx: a.dx - a.width / 2, + dy: a.dy, + }, + ]; + }); + renderNodeInfo.height = Math.max(renderNodeInfo.height, inboxHeight, outboxHeight); +} +/** + * Set size of an annotation node. + */ +function sizeAnnotation(a: render.Annotation): void { + switch (a.annotationType) { + case render.AnnotationType.CONSTANT: + _.extend(a, PARAMS.constant.size); + break; + case render.AnnotationType.SHORTCUT: + if (a.node.type === NodeType.OP) { + _.extend(a, PARAMS.shortcutSize.op); + } else if (a.node.type === NodeType.META) { + _.extend(a, PARAMS.shortcutSize.meta); + } else if (a.node.type === NodeType.SERIES) { + _.extend(a, PARAMS.shortcutSize.series); + } else if (a.node.type === NodeType.API_LIST) { + _.extend(a, PARAMS.shortcutSize.api_list); + } else { + throw Error('Invalid node type: ' + a.node.type); + } + break; + case render.AnnotationType.SUMMARY: + _.extend(a, PARAMS.constant.size); + break; + } +} +/** + * Determines the center position of the node's shape. The position depends + * on if the node has in and out-annotations. + */ +export function computeCXPositionOfNodeShape(renderInfo: render.RenderNodeInfo): number { + if (renderInfo.expanded) { + return renderInfo.x; + } + let dx = renderInfo.inAnnotations.list.length ? renderInfo.inboxWidth : 0; + return renderInfo.x - renderInfo.width / 2 + dx + renderInfo.coreBox.width / 2; +} +/** Returns the angle (in degrees) between two points. */ +function angleBetweenTwoPoints(a: render.Point, b: render.Point): number { + let dx = b.x - a.x; + let dy = b.y - a.y; + return (180 * Math.atan(dy / dx)) / Math.PI; +} +/** + * Returns if a line going through the specified points is a straight line. + */ +function isStraightLine(points: render.Point[]) { + let angle = angleBetweenTwoPoints(points[0], points[1]); + for (let i = 1; i < points.length - 1; i++) { + let newAngle = angleBetweenTwoPoints(points[i], points[i + 1]); + // Have a tolerance of 1 degree. + if (Math.abs(newAngle - angle) > 1) { + return false; + } + angle = newAngle; + } + return true; +} +/** + * Returns the intersection of a line between the provided point + * and the provided rectangle. + */ +function intersectPointAndNode(point: render.Point, node: render.RenderNodeInfo): render.Point { + // cx and cy are the center of the rectangle. + let cx = node.expanded ? node.x : computeCXPositionOfNodeShape(node); + let cy = node.y; + // Calculate the slope + let dx = point.x - cx; + let dy = point.y - cy; + let w = node.expanded ? node.width : node.coreBox.width; + let h = node.expanded ? node.height : node.coreBox.height; + let deltaX, deltaY; + if ((Math.abs(dy) * w) / 2 > (Math.abs(dx) * h) / 2) { + // The intersection is above or below the rectangle. + if (dy < 0) { + h = -h; + } + deltaX = dy === 0 ? 0 : ((h / 2) * dx) / dy; + deltaY = h / 2; + } else { + // The intersection is left or right of the rectangle. + if (dx < 0) { + w = -w; + } + deltaX = w / 2; + deltaY = dx === 0 ? 0 : ((w / 2) * dy) / dx; + } + return { x: cx + deltaX, y: cy + deltaY }; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/loader.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/loader.ts new file mode 100644 index 0000000000000000000000000000000000000000..fcc7cdada4a405f9f1b701f588c69f4b192dbcf9 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/loader.ts @@ -0,0 +1,83 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as tb_debug from '../tb_debug'; +import * as tf_graph_common from './common'; +import * as tf_graph from './graph'; +import * as hierarchy from './hierarchy'; +import * as op from './op'; +import * as parser from './parser'; +import * as tf_graph_util from './util'; + +export type GraphAndHierarchy = { + graph: tf_graph.SlimGraph; + graphHierarchy: hierarchy.Hierarchy; +}; +export function fetchAndConstructHierarchicalGraph( + tracker: tf_graph_common.ProgressTracker, + remotePath: string | null, + pbTxtFile: Blob | null, + compatibilityProvider: op.CompatibilityProvider = new op.TpuCompatibilityProvider(), + hierarchyParams: hierarchy.HierarchyParams = hierarchy.DefaultHierarchyParams, +): Promise { + const dataTracker = tf_graph_util.getSubtaskTracker(tracker, 30, 'Data'); + const graphTracker = tf_graph_util.getSubtaskTracker(tracker, 20, 'Graph'); + const hierarchyTracker = tf_graph_util.getSubtaskTracker(tracker, 50, 'Namespace hierarchy'); + const start = Date.now(); + return parser + .fetchAndParseGraphData(remotePath!, pbTxtFile!, dataTracker) + .then( + function (graph) { + if (!graph.node) { + throw new Error( + 'The graph is empty. This can happen when ' + + 'TensorFlow could not trace any graph. Please refer to ' + + 'https://github.com/tensorflow/tensorboard/issues/1961 for more ' + + 'information.', + ); + } + return tf_graph.build(graph, tf_graph.DefaultBuildParams, graphTracker); + }, + () => { + throw new Error( + 'Malformed GraphDef. This can sometimes be caused by ' + + 'a bad network connection or difficulty reconciling multiple ' + + 'GraphDefs; for the latter case, please refer to ' + + 'https://github.com/tensorflow/tensorboard/issues/1929.', + ); + }, + ) + .then(async (graph) => { + // Populate compatibile field of OpNode based on whitelist + op.checkOpsForCompatibility(graph, compatibilityProvider); + const graphHierarchy = await hierarchy.build(graph, hierarchyParams, hierarchyTracker); + tf_graph_util.notifyDebugEvent({ + timingId: tb_debug.GraphDebugEventId.GRAPH_LOAD_SUCCEEDED, + eventValue: Date.now() - start, + }); + return { graph, graphHierarchy }; + }) + .catch((e) => { + // Generic error catch, for errors that happened outside + // asynchronous tasks. + const msg = `Graph visualization failed.\n\n${e}`; + tracker.reportError(msg, e); + tf_graph_util.notifyDebugEvent({ + timingId: tb_debug.GraphDebugEventId.GRAPH_LOAD_FAILED, + eventValue: Date.now() - start, + }); + // Don't swallow the error. + throw e; + }); +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/minimap.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/minimap.ts new file mode 100644 index 0000000000000000000000000000000000000000..0e36d274896a8d603cb265cb96a424bf57baffa0 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/minimap.ts @@ -0,0 +1,359 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as d3 from 'd3'; + +const FRAC_VIEWPOINT_AREA: number = 0.8; +export class Minimap { + /** The minimap container. */ + private minimap: HTMLElement; + /** The canvas used for drawing the mini version of the svg. */ + private canvas: HTMLCanvasElement; + /** A buffer canvas used for temporary drawing to avoid flickering. */ + private canvasBuffer: HTMLCanvasElement; + private downloadCanvas: HTMLCanvasElement; + /** The minimap svg used for holding the viewpoint rectangle. */ + private minimapSvg: SVGSVGElement; + /** The rectangle showing the current viewpoint. */ + private viewpoint: SVGRectElement; + /** + * The scale factor for the minimap. The factor is determined automatically + * so that the minimap doesn't violate the maximum width/height specified + * in the constructor. The minimap maintains the same aspect ratio as the + * original svg. + */ + private scaleMinimap: number; + /** The main svg element. */ + private svg: SVGSVGElement; + /** The svg group used for panning and zooming the main svg. */ + private zoomG: SVGGElement; + /** The zoom behavior of the main svg. */ + private mainZoom: d3.ZoomBehavior; + /** The maximum width and height for the minimap. */ + private maxWandH: number; + /** The last translation vector used in the main svg. */ + private translate: [number, number]; + /** The last scaling factor used in the main svg. */ + private scaleMain: number; + /** The coordinates of the viewpoint rectangle. */ + private viewpointCoord: { + x: number; + y: number; + }; + /** The current size of the minimap */ + private minimapSize: { + width: number; + height: number; + }; + /** Padding (px) due to the main labels of the graph. */ + private labelPadding: number; + /** + * Constructs a new minimap. + * + * @param svg The main svg element. + * @param zoomG The svg group used for panning and zooming the main svg. + * @param mainZoom The main zoom behavior. + * @param minimap The minimap container. + * @param maxWandH The maximum width/height for the minimap. + * @param labelPadding Padding in pixels due to the main graph labels. + */ + constructor( + svg: SVGSVGElement, + zoomG: SVGGElement, + mainZoom: d3.ZoomBehavior, + minimap: HTMLElement, + maxWandH: number, + labelPadding: number + ) { + this.svg = svg; + this.labelPadding = labelPadding; + this.zoomG = zoomG; + this.mainZoom = mainZoom; + this.maxWandH = maxWandH; + let $shadowRoot = d3.select(minimap.shadowRoot as unknown as Element); + // The minimap will have 2 main components: the canvas showing the content + // and an svg showing a rectangle of the currently zoomed/panned viewpoint. + let $minimapSvg = $shadowRoot.select('svg'); + // Make the viewpoint rectangle draggable. + let $viewpoint = $minimapSvg.select('rect'); + let dragmove = (d) => { + this.viewpointCoord.x = (d3.event).x; + this.viewpointCoord.y = (d3.event).y; + this.updateViewpoint(); + }; + this.viewpointCoord = {x: 0, y: 0}; + let drag = d3.drag().subject(Object).on('drag', dragmove); + $viewpoint.datum(this.viewpointCoord as any).call(drag); + // Make the minimap clickable. + $minimapSvg.on('click', () => { + if ((d3.event).defaultPrevented) { + // This click was part of a drag event, so suppress it. + return; + } + // Update the coordinates of the viewpoint. + let width = Number($viewpoint.attr('width')); + let height = Number($viewpoint.attr('height')); + let clickCoords = d3.mouse($minimapSvg.node() as any); + this.viewpointCoord.x = clickCoords[0] - width / 2; + this.viewpointCoord.y = clickCoords[1] - height / 2; + this.updateViewpoint(); + }); + this.viewpoint = $viewpoint.node(); + this.minimapSvg = $minimapSvg.node(); + this.minimap = minimap; + this.canvas = $shadowRoot.select('canvas.first').node(); + this.canvasBuffer = ( + $shadowRoot.select('canvas.second').node() + ); + this.downloadCanvas = ( + $shadowRoot.select('canvas.download').node() + ); + d3.select(this.downloadCanvas).style('display', 'none'); + this.update(); + } + /** + * Updates the position and the size of the viewpoint rectangle. + * It also notifies the main svg about the new panned position. + */ + private updateViewpoint(): void { + // Update the coordinates of the viewpoint rectangle. + d3.select(this.viewpoint) + .attr('x', this.viewpointCoord.x) + .attr('y', this.viewpointCoord.y); + // Update the translation vector of the main svg to reflect the + // new viewpoint. + let mainX = (-this.viewpointCoord.x * this.scaleMain) / this.scaleMinimap; + let mainY = (-this.viewpointCoord.y * this.scaleMain) / this.scaleMinimap; + d3.select(this.svg).call( + this.mainZoom.transform, + d3.zoomIdentity.translate(mainX, mainY).scale(this.scaleMain) + ); + } + /** + * Takes a snapshot of the graph's image as a Blob. + */ + getImageBlob(): Promise { + return new Promise((resolve) => { + this.downloadCanvas.toBlob((blob) => { + resolve(blob!); + }, 'image/png'); + }); + } + /** + * Redraws the minimap. Should be called whenever the main svg + * was updated (e.g. when a node was expanded). + */ + update(): void { + let sceneSize: DOMRect | null = null; + try { + // Get the size of the entire scene. + sceneSize = this.zoomG.getBBox(); + if (sceneSize.width === 0) { + // There is no scene anymore. We have been detached from the dom. + return; + } + } catch (e) { + // Firefox produced NS_ERROR_FAILURE if we have been + // detached from the dom. + return; + } + let $svg = d3.select(this.svg); + // Read all the style rules in the document and embed them into the svg. + // The svg needs to be self contained, i.e. all the style rules need to be + // embedded so the canvas output matches the origin. + let stylesText = ''; + const anySvg = this.svg as any; + // MSEdge does not have `getRootNode`. In that case, manually get the root + // node. This is more brittle than the getRootNode as changing DOM structure + // will break this. + const rootNode = anySvg.getRootNode + ? anySvg.getRootNode() + : this.svg.parentNode; + const styleSheets = rootNode.styleSheets; + for (let k = 0; k < styleSheets.length; k++) { + try { + let cssRules = + (styleSheets[k]).cssRules || (styleSheets[k]).rules; + if (cssRules == null) { + continue; + } + for (let i = 0; i < cssRules.length; i++) { + // Remove tf-* selectors from the styles. + stylesText += + cssRules[i].cssText.replace(/ ?tf-[\w-]+ ?/g, '') + '\n'; + } + } catch (e: any) { + if (e.name !== 'SecurityError') { + throw e; + } + } + } + // Temporarily add the css rules to the main svg. + let svgStyle = $svg.append('style'); + svgStyle.text(stylesText); + // Temporarily remove the zoom/pan transform from the main svg since we + // want the minimap to show a zoomed-out and centered view. + let $zoomG = d3.select(this.zoomG); + let zoomTransform = $zoomG.attr('transform'); + $zoomG.attr('transform', null); + // https://github.com/tensorflow/tensorboard/issues/1598 + // Account for SVG content shift. SVGGraphicsElement.getBBox().width returns + // width in pixel value of very tight bounding box of non-empty content. + // Since we want to measure the sceneSize from the origin to the right most + // edge of the right most node, we need to account for distance from the + // origin to the left edge of the bounding box. + sceneSize.height += sceneSize.y; + sceneSize.width += sceneSize.x; + // Since we add padding, account for that here. + sceneSize.height += this.labelPadding * 2; + sceneSize.width += this.labelPadding * 2; + // Temporarily assign an explicit width/height to the main svg, since + // it doesn't have one (uses flex-box), but we need it for the canvas + // to work. + $svg.attr('width', sceneSize.width).attr('height', sceneSize.height); + // Since the content inside the svg changed (e.g. a node was expanded), + // the aspect ratio have also changed. Thus, we need to update the scale + // factor of the minimap. The scale factor is determined such that both + // the width and height of the minimap are <= maximum specified w/h. + this.scaleMinimap = + this.maxWandH / Math.max(sceneSize.width, sceneSize.height); + this.minimapSize = { + width: sceneSize.width * this.scaleMinimap, + height: sceneSize.height * this.scaleMinimap, + }; + // Update the size of the minimap's svg, the buffer canvas and the + // viewpoint rect. + d3.select(this.minimapSvg).attr(this.minimapSize); + d3.select(this.canvasBuffer).attr(this.minimapSize); + // Download canvas width and height are multiples of the style width and + // height in order to increase pixel density of the PNG for clarity. + const downloadCanvasSelection = d3.select(this.downloadCanvas); + downloadCanvasSelection.style('width', sceneSize.width); + downloadCanvasSelection.style('height', sceneSize.height); + downloadCanvasSelection.attr('width', 3 * sceneSize.width); + downloadCanvasSelection.attr('height', 3 * sceneSize.height); + if (this.translate != null && this.zoom != null) { + // Update the viewpoint rectangle shape since the aspect ratio of the + // map has changed. + requestAnimationFrame(() => this.zoom()); + } + // TODO(stephanwlee): Consider not mutating the original DOM then read it -- + // this may cause reflow. + // Serialize the main svg to a string which will be used as the rendering + // content for the canvas. + let svgXml = new XMLSerializer().serializeToString(this.svg); + // Now that the svg is serialized for rendering, remove the temporarily + // assigned styles, explicit width and height and bring back the pan/zoom + // transform. + svgStyle.remove(); + $svg.attr('width', null).attr('height', null); + $zoomG.attr('transform', zoomTransform); + let image = new Image(); + image.onload = () => { + // Draw the svg content onto the buffer canvas. + let context = this.canvasBuffer.getContext('2d'); + context?.clearRect( + 0, + 0, + this.canvasBuffer.width, + this.canvasBuffer.height + ); + context?.drawImage( + image, + 0, + 0, + this.minimapSize.width, + this.minimapSize.height + ); + requestAnimationFrame(() => { + // Hide the old canvas and show the new buffer canvas. + d3.select(this.canvasBuffer).style('display', null); + d3.select(this.canvas).style('display', 'none'); + // Swap the two canvases. + [this.canvas, this.canvasBuffer] = [this.canvasBuffer, this.canvas]; + }); + let downloadContext = this.downloadCanvas.getContext('2d'); + downloadContext?.clearRect( + 0, + 0, + this.downloadCanvas.width, + this.downloadCanvas.height + ); + downloadContext?.drawImage( + image, + 0, + 0, + this.downloadCanvas.width, + this.downloadCanvas.height + ); + }; + image.onerror = () => { + let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'}); + image.src = (URL as any).createObjectURL(blob); + }; + image.src = + 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svgXml); + } + /** + * Handles changes in zooming/panning. Should be called from the main svg + * to notify that a zoom/pan was performed and this minimap will update it's + * viewpoint rectangle. + * + * @param translate The translate vector, or none to use the last used one. + * @param scale The scaling factor, or none to use the last used one. + */ + zoom(transform?: d3.ZoomTransform): void { + if (this.scaleMinimap == null) { + // Scene is not ready yet. + return; + } + // Update the new translate and scale params, only if specified. + if (transform) { + this.translate = [transform.x, transform.y]; + this.scaleMain = transform.k; + } + // Update the location of the viewpoint rectangle. + let svgRect = this.svg.getBoundingClientRect(); + let $viewpoint = d3.select(this.viewpoint); + this.viewpointCoord.x = + (-this.translate[0] * this.scaleMinimap) / this.scaleMain; + this.viewpointCoord.y = + (-this.translate[1] * this.scaleMinimap) / this.scaleMain; + let viewpointWidth = (svgRect.width * this.scaleMinimap) / this.scaleMain; + let viewpointHeight = (svgRect.height * this.scaleMinimap) / this.scaleMain; + $viewpoint + .attr('x', this.viewpointCoord.x) + .attr('y', this.viewpointCoord.y) + .attr('width', viewpointWidth) + .attr('height', viewpointHeight); + // Show/hide the minimap depending on the viewpoint area as fraction of the + // whole minimap. + let mapWidth = this.minimapSize.width; + let mapHeight = this.minimapSize.height; + let x = this.viewpointCoord.x; + let y = this.viewpointCoord.y; + let w = + Math.min(Math.max(0, x + viewpointWidth), mapWidth) - + Math.min(Math.max(0, x), mapWidth); + let h = + Math.min(Math.max(0, y + viewpointHeight), mapHeight) - + Math.min(Math.max(0, y), mapHeight); + let fracIntersect = (w * h) / (mapWidth * mapHeight); + if (fracIntersect < FRAC_VIEWPOINT_AREA) { + this.minimap.classList.remove('hidden'); + } else { + this.minimap.classList.add('hidden'); + } + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/node.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/node.ts new file mode 100644 index 0000000000000000000000000000000000000000..f5b471909df8b1f187e2eddc2fba0b64e6464fd0 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/node.ts @@ -0,0 +1,1461 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as d3 from 'd3'; +import * as _ from 'lodash'; +import * as tf_graph_common from './common'; +import { Class, FontSizeInPx, selectChild, selectOrCreateChild } from './common'; +import * as contextmenu from './contextmenu'; +import * as edge from './edge'; +import * as tf_graph from './graph'; +import * as tf_graph_proto from './proto'; +import { + BridgeNode, + getIncludeNodeButtonString, + Metanode, + MetanodeImpl, + Node, + NodeType, + OpNode, + OpNodeImpl, + SeriesNode, +} from './graph'; +import * as layout from './layout'; +import * as render from './render'; +import { RenderNodeInfo } from './render'; +import * as tf_graph_scene from './scene'; +import { positionEllipse, positionRect } from './scene'; +import { TfGraphScene } from './tf-graph-scene'; +import * as tf_graph_util from './util'; + +/** + * Select or Create a 'g.nodes' group to a given sceneGroup + * and builds a number of 'g.node' groups inside the group. + * + * Structure Pattern: + * + * + * + * + * ... + * + * + * ... + * + * + * + * + * node name + * + * + * + * + * ... + * + * + * + * @param sceneGroup selection of the container + * @param nodeData array of render node information to map + * @param sceneElement polymer element + * @return selection of the created nodeGroups + */ +let colorStorage: { [key: string]: string } = {}; +let matchStorage = []; +let unMatchStorage = []; +export function buildGroup(sceneGroup, nodeData: render.RenderNodeInfo[], sceneElement) { + let container = tf_graph_common.selectOrCreateChild(sceneGroup, 'g', Class.Node.CONTAINER); + // Select all children and join with data. + // (Note that all children of g.nodes are g.node) + let nodeGroups = (container as any) + .selectAll(function () { + return this.childNodes; + }) + .data(nodeData, (d) => { + // make sure that we don't have to swap shape type + return d.node.name + ':' + d.node.type; + }); + // ENTER + nodeGroups + .enter() + .append('g') + .attr('data-name', (d) => { + return d.node.name; + }) + .each(function (d) { + let nodeGroup = d3.select(this); + // index node group for quick stylizing + sceneElement.addNodeGroup(d.node.name, nodeGroup); + }) + .merge(nodeGroups) + // ENTER + UPDATE + .attr('class', (d) => { + return Class.Node.GROUP + ' ' + nodeClass(d); + }) + .each(function (d) { + let nodeGroup = d3.select(this); + // Add g.in-annotations (always add -- to keep layer order + // consistent.) + let inAnnotationBox = tf_graph_common.selectOrCreateChild(nodeGroup, 'g', Class.Annotation.INBOX); + buildGroupForAnnotation(inAnnotationBox, d.inAnnotations, d, sceneElement); + // Add g.out-annotations (always add -- to keep layer order + // consistent.) + let outAnnotationBox = tf_graph_common.selectOrCreateChild(nodeGroup, 'g', Class.Annotation.OUTBOX); + buildGroupForAnnotation(outAnnotationBox, d.outAnnotations, d, sceneElement); + // Build .shape first (background of the node). + let shape = buildShape(nodeGroup, d, Class.Node.SHAPE); + let shape1 = buildShape(nodeGroup, d, Class.Node.OUTER); + if (d.node.isGroupNode) { + addButton(shape, d, sceneElement); + } + if (d.node.isGroupNode) { + addButton(shape1, d, sceneElement); + } + addInteraction(shape, d, sceneElement); + addInteraction(shape1, d, sceneElement); + // Build subscene on the top. + subsceneBuild(nodeGroup, d, sceneElement); + // Build label last. Should be on top of everything else. + let label = labelBuild(nodeGroup, d, sceneElement); + // Do not add interaction to metanode labels as they live inside the + // metanode shape which already has the same interactions. + addInteraction(label, d, sceneElement, d.node.type === NodeType.META); + stylize(nodeGroup, d, sceneElement); + position(nodeGroup, d); + }); + // EXIT + nodeGroups + .exit() + .each(function (d) { + // remove all indices on remove + sceneElement.removeNodeGroup(d.node.name); + let nodeGroup = d3.select(this); + if (d.inAnnotations.list.length > 0) { + nodeGroup + .select('.' + Class.Annotation.INBOX) + .selectAll('.' + Class.Annotation.GROUP) + .each((a) => { + sceneElement.removeAnnotationGroup(a, d); + }); + } + if (d.outAnnotations.list.length > 0) { + nodeGroup + .select('.' + Class.Annotation.OUTBOX) + .selectAll('.' + Class.Annotation.GROUP) + .each((a) => { + sceneElement.removeAnnotationGroup(a, d); + }); + } + }) + .remove(); + return nodeGroups; +} +/** + * Update or remove the subscene of a render group node depending on whether it + * is a expanded. If the node is not a group node, this method has no effect. + * + * @param nodeGroup selection of the container + * @param renderNodeInfo the render information for the node. + * @param sceneElement polymer element. + * @return Selection of the subscene group, or null if node group does not have + * a subscene. Op nodes, bridge nodes and unexpanded group nodes will + * not have a subscene. + */ +function subsceneBuild(nodeGroup, renderNodeInfo: render.RenderGroupNodeInfo, sceneElement: TfGraphScene) { + if (renderNodeInfo.node.isGroupNode) { + if (renderNodeInfo.expanded) { + // Recursively build the subscene. + return buildGroupForScene(nodeGroup, renderNodeInfo, sceneElement, Class.Subscene.GROUP); + } + // Clean out existing subscene if the node is not expanded. + tf_graph_scene.selectChild(nodeGroup, 'g', Class.Subscene.GROUP).remove(); + } + return null; +} +/** + * Translate the subscene of the given node group + */ +function subscenePosition(nodeGroup, d: render.RenderNodeInfo) { + // Translate the subscene to the middle of the parent node in vertical direction. + let x0 = d.x - d.coreBox.width / 2; + let y0 = d.y - d.height / 2 + d.paddingTop; + let subscene = tf_graph_scene.selectChild(nodeGroup, 'g', Class.Subscene.GROUP); + tf_graph_scene.translate(subscene, x0, y0); +} +/** + * Add an expand/collapse button to a group node + * + * @param selection The group node selection. + * @param d Info about the node being rendered. + * @param sceneElement polymer element. + */ +function addButton(selection, d: render.RenderNodeInfo, sceneElement) { + let group = tf_graph_common.selectOrCreateChild(selection, 'g', Class.Node.BUTTON_CONTAINER); + tf_graph_common.selectOrCreateChild(group, 'circle', Class.Node.BUTTON_CIRCLE); + tf_graph_common.selectOrCreateChild(group, 'path', Class.Node.EXPAND_BUTTON).attr('d', 'M0,-2.2 V2.2 M-2.2,0 H2.2'); + tf_graph_common.selectOrCreateChild(group, 'path', Class.Node.COLLAPSE_BUTTON).attr('d', 'M-2.2,0 H2.2'); + (group as any).on('click', (d: any) => { + // Stop this event's propagation so that it isn't also considered a + // node-select. + (d3.event).stopPropagation(); + sceneElement.fire('node-toggle-expand', { name: d.node.name }); + }); + tf_graph_scene.positionButton(group, d); +} +/** + * Fire node-* events when the selection is interacted. + * + * @param disableInteraction When true, have the provided selection + * ignore all pointer events. Used for text labels inside of metanodes, which + * don't need interaction as their surrounding shape has interaction, and if + * given interaction would cause conflicts with the expand/collapse button. + */ +function addInteraction(selection, d: render.RenderNodeInfo, sceneElement: TfGraphScene, disableInteraction?: boolean) { + if (disableInteraction) { + selection.attr('pointer-events', 'none'); + return; + } + let contextMenuFunction = contextmenu.getMenu(sceneElement, getContextMenu(d.node, sceneElement), d); + let time = 0; + let timeOut; + let mouseMoved = false; + let startX, startY; + const movementThreshold = 5; + selection + .on('dblclick', (d) => { + clearTimeout(timeOut); + sceneElement.fire('node-toggle-expand', { name: d.node.name }); + }) + .on('mouseover', (d) => { + // don't send mouseover over expanded group, + // otherwise it is causing too much glitches + if (sceneElement.isNodeExpanded(d)) { + return; + } + sceneElement.fire('node-highlight', { name: d.node.name }); + }) + .on('mouseout', (d) => { + // don't send mouseover over expanded group, + // otherwise it is causing too much glitches + if (sceneElement.isNodeExpanded(d)) { + return; + } + sceneElement.fire('node-unhighlight', { name: d.node.name }); + }) + .on('mousedown', () => { + startX = d3.event.clientX; + startY = d3.event.clientY; + mouseMoved = false; // 重置标志变量 + }) + + // 监听鼠标抬起事件,检查是否超过浮动阈值 + .on('mouseup', () => { + const deltaX = Math.abs(d3.event.clientX - startX); + const deltaY = Math.abs(d3.event.clientY - startY); + if (deltaX > movementThreshold || deltaY > movementThreshold) { + mouseMoved = true; + } + }) + + .on('click', (d) => { + clearTimeout(timeOut); // 清除第一个单击事件 + if (mouseMoved) { + mouseMoved = false; // 重置标志变量 + return; + } + timeOut = setTimeout(function () { + sceneElement.fire('node-select', { name: d.node.name }); + }, time); + }) + + .on('contextmenu', (d, i) => { + sceneElement.fire('node-select', { name: d.node.name }); + contextMenuFunction.call(d, i); + }); +} +/** + * Returns the d3 context menu specification for the provided node. + */ +export function getContextMenu(node: Node, sceneElement) { + let menu = [ + { + title: (d): string => { + return getIncludeNodeButtonString(node.include); + }, + action: (elm, d, i) => { + sceneElement.fire('node-toggle-extract', { name: node.name }); + }, + }, + ]; + if (sceneElement.nodeContextMenuItems) { + // Add these additional context menu items. + menu = menu.concat(sceneElement.nodeContextMenuItems); + } + if (canBeInSeries(node)) { + menu.push({ + title: (d) => { + return getGroupSettingLabel(node); + }, + action: (elm, d, i) => { + sceneElement.fire('node-toggle-seriesgroup', { + name: getSeriesName(node), + }); + }, + }); + } + return menu; +} +/** Returns if a node can be part of a grouped series */ +export function canBeInSeries(node: Node) { + return getSeriesName(node) !== null; +} +/** + * Returns the name of the possible grouped series containing this node. + * Returns null if the node cannot be part of a grouped series of nodes. + */ +export function getSeriesName(node: Node) { + if (!node) { + return null; + } + if (node.type === NodeType.SERIES) { + return node.name; + } + if (node.type === NodeType.OP) { + let op = node; + return op.owningSeries; + } + return null; +} +/** + * Returns the SeriesNode that represents the series that the provided node + * is contained in (or itself if the provided node is itself a SeriesNode). + * Returns null if the node is not rendered as part of a series. + */ +function getContainingSeries(node: Node) { + let s: SeriesNode | null = null; + if (!node) { + return null; + } else if (node.type === NodeType.SERIES) { + s = node; + } else if (node.parentNode && node.parentNode.type === NodeType.SERIES) { + s = node.parentNode; + } + return s; +} +/** + * Returns the label for a button to toggle the group setting of the provided + * node. + */ +export function getGroupSettingLabel(node: Node) { + return tf_graph.getGroupSeriesNodeButtonString( + getContainingSeries(node) !== null ? tf_graph.SeriesGroupingType.GROUP : tf_graph.SeriesGroupingType.UNGROUP, + ); +} +/** + * Append svg text for label and assign data. + * @param nodeGroup + * @param renderNodeInfo The render node information for the label. + * @param sceneElement polymer element. + */ +function labelBuild(nodeGroup, renderNodeInfo: render.RenderNodeInfo, sceneElement) { + let text = renderNodeInfo.displayName; + // Truncate long labels for unexpanded Metanodes. + let useFontScale = renderNodeInfo.node.type === (NodeType.META || NodeType.API_LIST) && !renderNodeInfo.expanded; + let label = tf_graph_common.selectOrCreateChild(nodeGroup, 'text', Class.Node.LABEL); + // Make sure the label is visually on top among its siblings. + let labelNode = label.node(); + labelNode.parentNode?.appendChild(labelNode); + label.attr('dy', '.35em').attr('text-anchor', 'middle'); + + // In tf-graph-scene styles, fontSizes are defined to vary from 6px to 9px. Since we + // do not want to invoke computedStyles or hardcode the fontSize that would be + // duplicated in styles, we are rounding it to 8px which does not cause any visible + // jank. + let fontSize = 8; + + switch (renderNodeInfo.node.type) { + case NodeType.META: + fontSize = renderNodeInfo.expanded ? FontSizeInPx.Node.EXPANDED_LABEL : FontSizeInPx.Node.SERIES_LABEL; + break; + case NodeType.OP: + fontSize = FontSizeInPx.Node.OP_LABEL; + break; + case NodeType.API_LIST: + fontSize = renderNodeInfo.expanded ? FontSizeInPx.Node.EXPANDED_LABEL : FontSizeInPx.Node.SERIES_LABEL; + break; + } + + if (useFontScale) { + if (text.length > sceneElement.maxMetanodeLabelLength) { + text = text.substr(0, sceneElement.maxMetanodeLabelLength - 2) + '…'; + } + let scale = getLabelFontScale(sceneElement); + label.attr('font-size', scale(text.length) + 'px'); + fontSize = scale(text.length); + } + let txtElement = >label.text(text); + enforceLabelWidth(txtElement, renderNodeInfo.node.type, fontSize, renderNodeInfo); + return label; +} + +/** + * This function shortens text which would exceed the maximum pixel width of + * a label. + * + * @param txtElementSelection The text element containing the label's text as d3 + * selection. + * @param nodeType The type of the node the label belongs to. If the node is + * an annotation, the value is -1. Label widths are defined in + * layout.PARAMS.nodeSize.{meta|op|...}.maxLabelWidth for nodes and + * layout.PARAMS.annotations.labelWidth for annotations. + * @param renderNodeInfo The render information about the node, required to + * determine whether META nodes are collapsed or expanded. + */ +export function enforceLabelWidth( + txtElementSelection: d3.Selection, + nodeType: NodeType | number, + fontSize: number, + renderNodeInfo?: render.RenderNodeInfo, +): any { + // Get text element itself and its on-screen width. + let txtNode = txtElementSelection.node(); + let labelContent = txtNode.textContent; + + // Get maximum length from settings. + let maxLength: number | null = null; + switch (nodeType) { + case NodeType.META: + if (renderNodeInfo && !renderNodeInfo.expanded) { + // Only trim text if + // node expanded. + maxLength = layout.PARAMS.nodeSize.meta.maxLabelWidth; + } + break; + case NodeType.API_LIST: + if (renderNodeInfo && !renderNodeInfo.expanded) { + // Only trim text if + // node expanded. + maxLength = layout.PARAMS.nodeSize.api_list.maxLabelWidth; + } + break; + case NodeType.OP: + maxLength = layout.PARAMS.nodeSize.op.maxLabelWidth; + break; + case -1: + maxLength = layout.PARAMS.annotations.maxLabelWidth; + break; + default: + break; + } + if (maxLength === null) return; + + txtNode.textContent = tf_graph_util.maybeTruncateString(txtNode.textContent!, fontSize, maxLength); + // Add tooltip with full name and return. + return txtElementSelection.append('title').text(labelContent!); +} +/** + * d3 scale used for sizing font of labels, used by labelBuild, + * initialized once by getLabelFontScale. + */ +let fontScale: any = null; +function getLabelFontScale(sceneElement) { + if (!fontScale) { + fontScale = d3 + .scaleLinear() + .domain([sceneElement.maxMetanodeLabelLengthLargeFont, sceneElement.maxMetanodeLabelLength]) + .range([sceneElement.maxMetanodeLabelLengthFontSize, sceneElement.minMetanodeLabelLengthFontSize]) + .clamp(true); + } + return fontScale; +} +/** + * Set label position of a given node group + */ +function labelPosition(nodeGroup, cx: number, cy: number, yOffset: number) { + tf_graph_scene + .selectChild(nodeGroup, 'text', Class.Node.LABEL) + .transition() + .attr('x', cx) + .attr('y', cy + yOffset); +} +/** + * Select or append/insert shape for a node and assign renderNode + * as the shape's data. + * + * @param nodeGroup + * @param d Render node information. + * @param nodeClass class for the element. + * @return Selection of the shape. + */ +export function buildShape(nodeGroup, d, nodeClass: string): d3.Selection { + // Create a group to house the underlying visual elements. + let shapeGroup = tf_graph_common.selectOrCreateChild(nodeGroup, 'g', nodeClass); + // TODO: DOM structure should be templated in HTML somewhere, not JS. + switch (d.node.type) { + case NodeType.OP: + const opNode = d.node as OpNode; + if (_.isNumber(opNode.functionInputIndex) || _.isNumber(opNode.functionOutputIndex)) { + // This is input or output arg for a TensorFlow function. Use a special + // shape (a triangle) for them. + tf_graph_common.selectOrCreateChild(shapeGroup, 'polygon', Class.Node.COLOR_TARGET); + break; + } + tf_graph_common.selectOrCreateChild(shapeGroup, 'ellipse', Class.Node.COLOR_TARGET); + break; + case NodeType.SERIES: + // Choose the correct stamp to use to represent this series. + let stampType = 'annotation'; + let groupNodeInfo = d; + if (groupNodeInfo.coreGraph) { + stampType = groupNodeInfo.node.hasNonControlEdges ? 'vertical' : 'horizontal'; + } + let classList = [Class.Node.COLOR_TARGET]; + if (groupNodeInfo.isFadedOut) { + classList.push('faded-ellipse'); + } + tf_graph_common + .selectOrCreateChild(shapeGroup, 'use', classList) + .attr('xlink:href', '#op-series-' + stampType + '-stamp'); + tf_graph_common + .selectOrCreateChild(shapeGroup, 'rect', Class.Node.COLOR_TARGET) + .attr('rx', d.radius) + .attr('ry', d.radius); + break; + case NodeType.BRIDGE: + tf_graph_common + .selectOrCreateChild(shapeGroup, 'rect', Class.Node.COLOR_TARGET) + .attr('rx', d.radius) + .attr('ry', d.radius); + break; + case NodeType.META: + tf_graph_common + .selectOrCreateChild(shapeGroup, 'rect', Class.Node.COLOR_TARGET) + .attr('rx', d.radius) + .attr('ry', d.radius); + break; + case NodeType.API_LIST: + tf_graph_common + .selectOrCreateChild(shapeGroup, 'rect', Class.Node.COLOR_TARGET) + .attr('rx', d.radius) + .attr('ry', d.radius); + break; + default: + throw Error('Unrecognized node type : ' + d.node.type); + } + return shapeGroup; +} +export function nodeClass(d: render.RenderNodeInfo) { + switch (d.node.type) { + case NodeType.OP: + return Class.OPNODE; + case NodeType.META: + return Class.METANODE; + case NodeType.SERIES: + return Class.SERIESNODE; + case NodeType.BRIDGE: + return Class.BRIDGENODE; + case NodeType.ELLIPSIS: + return Class.ELLIPSISNODE; + case NodeType.API_LIST: + return Class.API_LIST; + } + throw Error('Unrecognized node type: ' + d.node.type); +} +/** Modify node and its subscene and its label's positional attributes */ +function position(nodeGroup, d: render.RenderNodeInfo) { + let shapeGroup = tf_graph_scene.selectChild(nodeGroup, 'g', Class.Node.SHAPE); + let shapeGroupHeader = tf_graph_scene.selectChild(nodeGroup, 'g', Class.Node.OUTER); + let cx = layout.computeCXPositionOfNodeShape(d); + switch (d.node.type) { + case NodeType.OP: { + // position shape + const opNode = d.node as OpNode; + if (_.isNumber(opNode.functionInputIndex) || _.isNumber(opNode.functionOutputIndex)) { + // This shape represents the input into or output out of a TensorFlow + // function. + let shape = tf_graph_scene.selectChild(shapeGroup, 'polygon'); + tf_graph_scene.positionTriangle(shape, d.x, d.y, d.coreBox.width, d.coreBox.height); + } else { + let shape = tf_graph_scene.selectChild(shapeGroup, 'ellipse'); + tf_graph_scene.positionEllipse(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + } + labelPosition(nodeGroup, cx, d.y, d.labelOffset); + break; + } + case NodeType.META: { + // position shape + let shapes = shapeGroup.selectAll('rect'); + let INSIDE_RECT_OFFSET = 0; + // y值为定位值 取值为(15(色块高度)/2 (中心点y值) + 0.6 (向下偏移值, 不偏移会覆盖边及其颜色))得出 + // 偏移值为8.1 + let OFFSET_VALUE = 8.1; + if (d.expanded) { + tf_graph_scene.positionRect(shapes, d.x, d.y, d.width, d.height); + INSIDE_RECT_OFFSET = d.y - d.height / 2 + OFFSET_VALUE; + subscenePosition(nodeGroup, d); + // Put the label on top. + labelPosition(nodeGroup, cx, d.y, -d.height / 2 + d.labelHeight / 2); + } else { + tf_graph_scene.positionRect(shapes, cx, d.y, d.coreBox.width, d.coreBox.height); + // Place the label in the middle. + labelPosition(nodeGroup, cx, d.y, 0); + } + let shapesHeader = shapeGroupHeader.selectAll('rect'); + if (d.expanded) { + tf_graph_scene.positionRect(shapesHeader, d.x, INSIDE_RECT_OFFSET, d.width - 1, 15); + subscenePosition(nodeGroup, d); + // Put the label on top. + labelPosition(nodeGroup, cx, d.y, -d.height / 2 + d.labelHeight / 2); + } else { + tf_graph_scene.positionRect(shapesHeader, cx, d.y, d.coreBox.width, d.coreBox.height); + // Place the label in the middle. + labelPosition(nodeGroup, cx, d.y, 0); + } + break; + } + case NodeType.SERIES: { + let shape = tf_graph_scene.selectChild(shapeGroup, 'use'); + if (d.expanded) { + tf_graph_scene.positionRect(shape, d.x, d.y, d.width, d.height); + subscenePosition(nodeGroup, d); + // put label on top + labelPosition(nodeGroup, cx, d.y, -d.height / 2 + d.labelHeight / 2); + } else { + tf_graph_scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height); + labelPosition(nodeGroup, cx, d.y, d.labelOffset); + } + break; + } + case NodeType.BRIDGE: { + // position shape + // NOTE: In reality, these will not be visible, but it helps to put them + // in the correct position for debugging purposes. + let shape = tf_graph_scene.selectChild(shapeGroup, 'rect'); + tf_graph_scene.positionRect(shape, d.x, d.y, d.width, d.height); + break; + } + case NodeType.API_LIST: { + // position shape + let shapes = shapeGroup.selectAll('rect'); + if (d.expanded) { + tf_graph_scene.positionRect(shapes, d.x, d.y, d.width, d.height); + subscenePosition(nodeGroup, d); + // Put the label on top. + labelPosition(nodeGroup, cx, d.y, -d.height / 2 + d.labelHeight / 2); + } else { + tf_graph_scene.positionRect(shapes, cx, d.y, d.coreBox.width, d.coreBox.height); + // Place the label in the middle. + labelPosition(nodeGroup, cx, d.y, 0); + } + break; + } + default: { + throw Error('Unrecognized node type: ' + d.node.type); + } + } +} + +function getGradient( + id: string, + colors: Array<{ + color: string; + proportion: number; + }>, + svgRoot?: SVGElement, +): string { + let escapedId = tf_graph_util.escapeQuerySelector(id); + if (!svgRoot) return `url(#${escapedId})`; + let $svgRoot = d3.select(svgRoot); + let gradientDefs = $svgRoot.select('defs#_graph-gradients'); + if (gradientDefs.empty()) { + gradientDefs = $svgRoot.append('defs').attr('id', '_graph-gradients'); + } + let linearGradient = gradientDefs.select('linearGradient#' + escapedId); + // If the linear gradient is not there yet, create it. + if (linearGradient.empty()) { + linearGradient = gradientDefs.append('linearGradient').attr('id', id); + // Re-create the stops of the linear gradient. + linearGradient.selectAll('*').remove(); + let cumulativeProportion = 0; + // For each color, create a stop using the proportion of that device. + _.each(colors, (d) => { + let color = d.color; + linearGradient.append('stop').attr('offset', cumulativeProportion).attr('stop-color', color); + linearGradient + .append('stop') + .attr('offset', cumulativeProportion + d.proportion) + .attr('stop-color', color); + cumulativeProportion += d.proportion; + }); + } + return `url(#${escapedId})`; +} +export function removeGradientDefinitions(svgRoot: SVGElement) { + d3.select(svgRoot).select('defs#_graph-gradients').remove(); +} + +function getColorByPrecisionIndex(precisionStr: string): string | undefined { + const precision = Number(precisionStr); + if (isNaN(precision)) { + return 'white'; + } + if (Object.entries(colorStorage).length !== 0) { + for (const [color, details] of Object.entries(colorStorage)) { + const detailsArray: any[] = [details]; + const [start, end] = detailsArray[0].value; + //进入md5模式 + if (start === end) { + if (precision === start) { + return color; + } + } + //对对比匹配成功的节点进行校验,如果precision大于1,按照最高等级处理,及颜色列表倒数第二个,倒数第一是未匹配节点的灰色 + else if (precision >= 1) { + const lastColor = Object.entries(colorStorage)[Object.entries(colorStorage).length - 2][0]; + return lastColor; + } + //其他区间模式, 最后一个区间的右侧一定为1,所以特化precision == 1的情况 + else if ((precision >= start && precision < end) || (precision === end && end === 1)) { + return color; + } + } + return 'white'; + } else { + const colorMap = [ + { precision: 0.2, color: '#ff704d' }, + { precision: 0.4, color: '#FFC62E' }, + { precision: 0.6, color: '#FFDC7F' }, + { precision: 0.8, color: '#FFEDBE' }, + { precision: 1, color: '#FFFCF3' }, + ]; + for (const range of colorMap) { + if (precision <= range.precision) { + return range.color; + } + } + return 'white'; + } +} +/** + * Returns the fill color for the node given its state and the 'color by' + * option. + * Takes in optional svgRoot, when passed, that populates SVG definitions + * for the fill inside the svgRoot when necessary. + */ +export function getFillForNode(renderInfo: render.RenderNodeInfo): string { + if (renderInfo.node instanceof OpNodeImpl || renderInfo.node instanceof MetanodeImpl) { + const precisionItem = renderInfo.node.attr.find((item) => item.key === tf_graph.PRECISION_INDEX); + const matchedNodeLink = renderInfo.node.matchedNodeLink; //标杆侧没有任何颜色 + const result = matchStorage.find((item) => Array.isArray(item) && item[0] === renderInfo.node.name); + const unmatch = unMatchStorage.find((item) => item === renderInfo.node.name); + if (result) { + return getColorByPrecisionIndex(result[1]); + } + if (unmatch) { + return '#C7C7C7'; + } + // 以前缀来判断是单图节点还是对比图节点 + if (renderInfo.node.name.startsWith('B___') || renderInfo.node.name.startsWith('N___')) { + if (!matchedNodeLink && !renderInfo.node.name.startsWith('B___')) { + return '#C7C7C7'; + } + } + if (precisionItem) { + return getColorByPrecisionIndex(precisionItem.value); + } else { + return 'transparent'; + } + } else { + // Other nodes are white. + return 'transparent'; + } +} +/** + * Modify node style by toggling class and assign attributes (only for things + * that can't be done in css). + */ +export function stylize(nodeGroup, renderInfo: render.RenderNodeInfo, sceneElement: TfGraphScene, nodeClass?) { + nodeClass = nodeClass || Class.Node.SHAPE || Class.Node.OUTER; + const isHighlighted = sceneElement.isNodeHighlighted(renderInfo.node.name); + const isSelected = sceneElement.isNodeSelected(renderInfo.node.name); + const isExtract = renderInfo.isInExtract || renderInfo.isOutExtract || renderInfo.isLibraryFunction; + const isExpanded = renderInfo.expanded && nodeClass !== Class.Annotation.NODE; + const isFadedOut = renderInfo.isFadedOut; + const isLinked = sceneElement.isNodeLinked(renderInfo.node.name); + nodeGroup.classed('highlighted', isHighlighted); + nodeGroup.classed('selected', isSelected); + nodeGroup.classed('extract', isExtract); + nodeGroup.classed('expanded', isExpanded); + nodeGroup.classed('faded', isFadedOut); + nodeGroup.classed('linked', isLinked); + // Main node always exists here and it will be reached before subscene, + // so d3 selection is fine here.OLOR_TARGET) + const node = nodeGroup.select(`.${nodeClass} .${Class.Node.COLOR_TARGET}`); + const outerNode = nodeGroup.select(`.${Class.Node.OUTER} .${Class.Node.COLOR_TARGET}`); + const fillColor = getFillForNode(renderInfo); + if (renderInfo.node.type === 0) { + node.style('fill', 'white'); + outerNode.style('fill', fillColor); + } else { + node.style('fill', fillColor); + } + // Choose outline to be darker version of node color if the node is a single + // color and is not selected. + node.style('stroke', isSelected ? null : getStrokeForFill(fillColor === 'transparent' ? 'white' : fillColor)); +} +/** + * Given a node's fill color/gradient, determine the stroke for the node. + */ +export function getStrokeForFill(fill: string) { + // If node is colored by a gradient, then use a dark gray outline. + if (fill.substring(0, 3) === 'url') { + return render.MetanodeColors.GRADIENT_OUTLINE; + } else if (fill.startsWith('rgba')) { + return 'rgb(167, 167, 167)'; + } else { + return d3.rgb(fill).darker().toString(); + } +} +/** + * Finds selected node and highlights all nodes which are providing direct + * or indirect input to the node and all edges connecting these nodes + * together and to the selected node. + */ +export function updateInputTrace( + svgRoot: SVGElement, + renderGraphInfo: render.RenderGraphInfo, + selectedNodeName: string, + traceInputs: Boolean, +) { + // Reset all styling. + const svg = d3.select(svgRoot); + svg.selectAll('.input-highlight').classed('input-highlight', false); + svg.selectAll('.non-input').classed('non-input', false); + svg.selectAll('.input-parent').classed('input-parent', false); + svg.selectAll('.input-child').classed('input-child', false); + svg.selectAll('.input-edge-highlight').classed('input-edge-highlight', false); + svg.selectAll('.non-input-edge-highlight').classed('non-input-edge-highlight', false); + svg.selectAll('.input-highlight-selected').classed('input-highlight-selected', false); + // Extract currently selected node. Return if input tracing disabled or no + // node is selected. + if (!renderGraphInfo || !traceInputs || !selectedNodeName) { + return; + } + let opNodes = _getAllContainedOpNodes(selectedNodeName, renderGraphInfo); + let allTracedNodes = {}; + _.each(opNodes, function (nodeInstance) { + allTracedNodes = traceAllInputsOfOpNode(svgRoot, renderGraphInfo, nodeInstance, allTracedNodes); + }); + // Highlight all parent nodes of each OpNode as input parent to allow + // specific highlighting. + let highlightedNodes = Object.keys(allTracedNodes); + let visibleNodes = _findVisibleParentsFromOpNodes(renderGraphInfo, highlightedNodes); + _markParentsOfNodes(svgRoot, visibleNodes); + // Attach class to all non-input nodes and edges for styling. + svg + .selectAll('g.node:not(.selected):not(.input-highlight)' + ':not(.input-parent):not(.input-children)') + .classed('non-input', true) + .each(function (d: RenderNodeInfo) { + // Mark all nodes with the specified name as non-inputs. This + // results in Annotation nodes which are attached to inputs to be + // tagged as well. + let nodeName = d.node.name; + svg.selectAll(`[data-name="${nodeName}"]`).classed('non-input', true); + }); + svg.selectAll('g.edge:not(.input-edge-highlight)').classed('non-input-edge-highlight', true); +} +/** + * Recursively find all op nodes contained by the node identified by the + * provided name. + * @param nodeName The meta or op node of which the OpNode instances are + * required. + * @param renderGraphInfo The rendered graph information object. + * @returns {Array} An array of OpNodeImpl instances. + */ +function _getAllContainedOpNodes(nodeName: string, renderGraphInfo: render.RenderGraphInfo): ReadonlyArray { + let opNodes: OpNodeImpl[] = []; + // Get current node. + let node = renderGraphInfo.getNodeByName(nodeName) as tf_graph.GroupNode | tf_graph.OpNode; + // If node is already OpNode then return the node plus its input embeddings. + if (node instanceof tf_graph.OpNodeImpl) { + return [node].concat(node.inEmbeddings); + } + // Otherwise, make recursive call for each node contained by the GroupNode. + let childNodeNames = (node as tf_graph.GroupNode).metagraph.nodes(); + _.each(childNodeNames, function (childNodeName) { + opNodes = opNodes.concat(_getAllContainedOpNodes(childNodeName, renderGraphInfo)); + }); + return opNodes; +} +/** + * When resolving inputs of a node the visible parent node of each input + * node (i.e. the first parent which is rendered to the screen) needs to be + * found, and since such a node may contain several input OpNodes a map + * of the visible parent to all the input OpNodes it contains is provided by + * opNodes. + */ +interface VisibleParent { + visibleParent: Node; + opNodes: OpNode[]; +} +function traceAllInputsOfOpNode( + svgRoot: SVGElement, + renderGraphInfo: render.RenderGraphInfo, + startNode: OpNode, + allTracedNodes: Object, +) { + // To prevent infinite loops due to cyclical relationships and improving + // performance by tracing OpNode which is input to 2+ nodes only once. + if (allTracedNodes[startNode.name]) { + return allTracedNodes; + } else { + allTracedNodes[startNode.name] = true; + } + // Extract the inputs. + let inputs = startNode.inputs; + // Get visible parent. + let currentVisibleParent = getVisibleParent(renderGraphInfo, startNode); + // Mark as input node. + d3.select(svgRoot).select(`.node[data-name="${currentVisibleParent.name}"]`).classed('input-highlight', true); + // Find the visible parent of each input. + let visibleInputs = {}; + _.each(inputs, function (nodeInstance) { + let resolvedNode = renderGraphInfo.getNodeByName(nodeInstance.name); + if (resolvedNode === undefined) { + // Node could not be found in rendered Hierarchy, which happens when + // tracing inputs of a SummaryNode. + return; + } + // Ensure node is resolved to OpNode if name collision with Metanode exists. + if (resolvedNode instanceof MetanodeImpl) { + let resolvedNodeName = tf_graph.getStrictName(resolvedNode.name); + resolvedNode = renderGraphInfo.getNodeByName(resolvedNodeName) as OpNode; + } + let visibleParent = getVisibleParent(renderGraphInfo, resolvedNode); + // Append OpNode to visible parent entry. + let visibleInputsEntry = visibleInputs[visibleParent.name]; + if (visibleInputsEntry) { + visibleInputsEntry.opNodes.push(resolvedNode); + } else { + // Create new entry. + visibleInputs[visibleParent.name] = { + visibleParent: visibleParent, + opNodes: [resolvedNode], + } as VisibleParent; + } + }); + // Find all parents of the start node. + let startNodeParents = {}; + let indexedStartNodeParents = [currentVisibleParent]; + startNodeParents[currentVisibleParent.name] = { + traced: false, + index: 0, + connectionEndpoints: [], + }; + let currentNode = currentVisibleParent as Node; + for (let index = 1; currentNode.name !== tf_graph.ROOT_NAME; index++) { + currentNode = currentNode.parentNode; + startNodeParents[currentNode.name] = { + traced: false, + index: index, + connectionEndpoints: [], + }; + indexedStartNodeParents[index] = currentNode; + } + // Find first mutual parent of each input node and highlight connection. + _.forOwn(visibleInputs, function (visibleParentInfo: VisibleParent, key) { + let nodeInstance = visibleParentInfo.visibleParent; + // Make recursive call for each input-OpNode contained by the visible + // parent. + _.each(visibleParentInfo.opNodes, function (opNode: OpNode) { + allTracedNodes = traceAllInputsOfOpNode(svgRoot, renderGraphInfo, opNode, allTracedNodes); + }); + if (nodeInstance.name !== currentVisibleParent.name) { + _createVisibleTrace(svgRoot, nodeInstance, startNodeParents, indexedStartNodeParents); + } + }); + return allTracedNodes; +} +/** + * Colors the edges to connect the passed node to the start node. This is + * done by: + * + * a) Finding the first (visible) common parent in the rendered + * hierarchy. + * NB: There are 2 types of connections: + * 1) Direct connections between node A + * and B, marked below as II, + * 2) Connections from any node A to its parent, A'. Marked below as I and III. + * For type 2 connection you need to know the inner-nested node, the + * direct parent, and the ultimate destination of the connection. + * + * A_parent B_parent + * +--------+ +---------+ + * | | | | + * | +--+ I| II |III+--+ | + * | |A +---------->+B | | + * | +--+ | | +--+ | + * | | | | + * +--------+ +---------+ + * + * + * b) Highlighting the direct connection between the parents of A and B, + * called A_parent and B_parent, s.t. A_parent and B_parent are children of the + * mutual parent of A and B found in a), marked above as II. + * + * c) Highlighting the connection from A to A_parent and B to B_parent + * (through all layers of parents between A and A_parent and B and B_parent, + * respectively). Marked above as I and III. + * + * @param nodeInstance The instance of the node to use as destination node, B. + * @param startNodeParents Map of startNodeParent names to information objects + * about the parent. + * @param indexedStartNodeParents An array of all parents of the start node. + * This is required to find the child of the mutual parent which is a parent + * of the start node. + * @private + */ +function _createVisibleTrace( + svgRoot: SVGElement, + nodeInstance: Node, + startNodeParents, + indexedStartNodeParents: Node[], +) { + let currentNode = nodeInstance; + let previousNode = nodeInstance; + // Ascend through parents until a mutual parent is found with the start + // node. + let destinationParentPairs: Node[][] = []; + while (!startNodeParents[currentNode.name]) { + if (previousNode.name !== currentNode.name) { + destinationParentPairs.push([previousNode, currentNode]); + } + previousNode = currentNode; + currentNode = currentNode.parentNode; + } + // Connection between nodes is drawn between the parents of each + // respective node, both of which share the mutual parent. + let startNodeIndex = startNodeParents[currentNode.name].index; + let startNodeName = indexedStartNodeParents[Math.max(startNodeIndex - 1, 0)].name; + let startNodeTopParentName = startNodeName; + let targetNodeTopParentName = previousNode.name; + let endNodeName = previousNode.name; + const svg = d3.select(svgRoot); + svg.selectAll(`[data-edge="${endNodeName}--${startNodeName}"]`).classed('input-edge-highlight', true); + // Trace up the parents of the input. + _.each(destinationParentPairs, function (value) { + let inner = value[0]; + let outer = value[1]; + let edgeSelector = `[data-edge="${inner.name}--${startNodeTopParentName}` + `~~${outer.name}~~OUT"]`; + svg.selectAll(edgeSelector).classed('input-edge-highlight', true); + }); + // Trace up the parents of the start node. + for (let index = 1; index < startNodeIndex; index++) { + let inner = indexedStartNodeParents[index - 1]; + let outer = indexedStartNodeParents[index]; + let edgeSelector = `[data-edge="${targetNodeTopParentName}~~${outer.name}` + `~~IN--${inner.name}"]`; + svg.selectAll(edgeSelector).classed('input-edge-highlight', true); + } +} +/** + * Creates map { [name: string] -> Node } of all visible / rendered parents + * of the nodes identified by the node names passed in. + * + * @param renderGraphInfo The information on the rendered graph. + * @param nodeNames String array of node names. + * @returns {[nodeName: string]: Node} + * @private + */ +function _findVisibleParentsFromOpNodes(renderGraphInfo, nodeNames: string[]) { + let visibleParents: { + [nodeName: string]: Node; + } = {}; + _.each(nodeNames, function (nodeName) { + let currentNode = renderGraphInfo.getNodeByName(nodeName); + let visibleParent = getVisibleParent(renderGraphInfo, currentNode); + visibleParents[visibleParent.name] = visibleParent; + }); + return visibleParents; +} +/** + * Traverse through the parents of all nodes in the list and mark each + * encountered node as input-parent. + * @param visibleNodes Map of input nodes, have to be visible/rendered when + * called. + * @private + */ +function _markParentsOfNodes( + svgRoot: SVGElement, + visibleNodes: { + [nodeName: string]: Node; + }, +) { + _.forOwn(visibleNodes, function (nodeInstance: Node) { + // Mark all parents of the node as input-parents. + let currentNode = nodeInstance; + while (currentNode.name !== tf_graph.ROOT_NAME) { + const renderedElementSelection = d3.select(svgRoot).select(`.node[data-name="${currentNode.name}"]`); + // Only mark the element as a parent node to an input if it is not + // marked as input node itself. + if ( + renderedElementSelection.nodes().length && + !renderedElementSelection.classed('input-highlight') && + !renderedElementSelection.classed('selected') && + // OpNode only parent if start node is embedded node, in which case + // the OpNode should be faded as well. + !renderedElementSelection.classed('op') + ) { + renderedElementSelection.classed('input-parent', true); + } + currentNode = currentNode.parentNode; + } + }); +} +/** + * Find the parent of the passed in op node which is expanded. This is done + * by going through all parents until the parent's parent is expanded, thus + * finding the first unexpanded parent which is rendered on the screen. + * @param renderGraphInfo The graph info object used to gain access to the + * render info of the parents. + * @param currentNode The node whose parent is to be found. + * @returns Node + */ +export function getVisibleParent(renderGraphInfo: render.RenderGraphInfo, currentNode: tf_graph.Node) { + let found = false; + let currentParent = currentNode; + while (!found) { + // Get parent element, to extract name. + currentNode = currentParent; + currentParent = currentNode.parentNode; + if (currentParent === undefined) { + found = true; + } else { + let renderNode = renderGraphInfo.getRenderNodeByName(currentParent.name); + // Found if node is rendered on the screen (renderNode truthy), and + // the parent is either expanded (i.e. it is a metanode or seriesnode) + // or the parent is an OpNode in which case currentNode is an embedded + // node which has another OpNode as parent. + if (renderNode && (renderNode.expanded || currentParent instanceof tf_graph.OpNodeImpl)) { + found = true; + } + } + } // Close while loop. + return currentNode; +} + +/** + * Annotations. + */ + +export function buildGroupForAnnotation( + container, + annotationData: render.AnnotationList, + d: render.RenderNodeInfo, + sceneElement, +) { + // Select all children and join with data. + let annotationGroups = container + .selectAll(function () { + // using d3's selector function + // See https://github.com/mbostock/d3/releases/tag/v2.0.0 + // (It's not listed in the d3 wiki.) + return this.childNodes; + }) + .data(annotationData.list, (d) => { + return d.node.name; + }); + annotationGroups + .enter() + .append('g') + .attr('data-name', (a) => { + return a.node.name; + }) + .each(function (a) { + let aGroup = d3.select(this); + // Add annotation to the index in the scene + sceneElement.addAnnotationGroup(a, d, aGroup); + // Append annotation edge + let edgeType = Class.Annotation.EDGE; + let metaedge = a.renderMetaedgeInfo && a.renderMetaedgeInfo.metaedge; + if (metaedge && !metaedge.numRegularEdges) { + edgeType += ' ' + Class.Annotation.CONTROL_EDGE; + } + // If any edges are reference edges, add the reference edge class. + if (metaedge && metaedge.numRefEdges) { + edgeType += ' ' + Class.Edge.REF_LINE; + } + edge.appendEdge(aGroup, a, sceneElement, edgeType); + if (a.annotationType !== render.AnnotationType.ELLIPSIS) { + addAnnotationLabelFromNode(aGroup, a); + buildShapeForAnnotation(aGroup, a); + } else { + addAnnotationLabel(aGroup, a.node.name, a, Class.Annotation.ELLIPSIS); + } + }) + .merge(annotationGroups) + .attr('class', (a) => { + return Class.Annotation.GROUP + ' ' + annotationToClassName(a.annotationType) + ' ' + nodeClass(a); + }) + .each(function (a) { + let aGroup = d3.select(this); + update(aGroup, d, a, sceneElement); + if (a.annotationType !== render.AnnotationType.ELLIPSIS) { + addInteractionForAnnotation(aGroup, d, a, sceneElement); + } + }); + annotationGroups + .exit() + .each(function (a) { + // Remove annotation from the index in the scene + sceneElement.removeAnnotationGroup(a, d); + }) + .remove(); + return annotationGroups; +} +/** + * Maps an annotation enum to a class name used in css rules. + */ +function annotationToClassName(annotationType: render.AnnotationType) { + return (render.AnnotationType[annotationType] || '').toLowerCase() || null; +} +function buildShapeForAnnotation(aGroup, a: render.Annotation) { + if (a.annotationType === render.AnnotationType.SUMMARY) { + let summary = selectOrCreateChild(aGroup, 'use'); + summary.attr('class', 'summary').attr('xlink:href', '#summary-icon').attr('cursor', 'pointer'); + } else { + let shape = buildShape(aGroup, a, Class.Annotation.NODE); + // add title tag to get native tooltips + selectOrCreateChild(shape, 'title').text(a.node.name); + } +} +function addAnnotationLabelFromNode(aGroup, a: render.Annotation) { + let namePath = a.node.name.split('/'); + let text = namePath[namePath.length - 1]; + return addAnnotationLabel(aGroup, text, a, null); +} +function addAnnotationLabel(aGroup, label: string, a: render.Annotation, additionalClassNames) { + let classNames = Class.Annotation.LABEL; + if (additionalClassNames) { + classNames += ' ' + additionalClassNames; + } + let txtElement = aGroup + .append('text') + .attr('class', classNames) + .attr('dy', '.35em') + .attr('text-anchor', a.isIn ? 'end' : 'start') + .text(label); + return enforceLabelWidth(txtElement, -1, FontSizeInPx.Annotation.LABEL); +} +function addInteractionForAnnotation(selection, d: render.RenderNodeInfo, annotation: render.Annotation, sceneElement) { + selection + .on('mouseover', (a) => { + sceneElement.fire('annotation-highlight', { + name: a.node.name, + hostName: d.node.name, + }); + }) + .on('mouseout', (a) => { + sceneElement.fire('annotation-unhighlight', { + name: a.node.name, + hostName: d.node.name, + }); + }) + .on('click', (a) => { + // Stop this event's propagation so that it isn't also considered a + // graph-select. + (d3.event).stopPropagation(); + sceneElement.fire('annotation-select', { + name: a.node.name, + hostName: d.node.name, + }); + }); + if ( + annotation.annotationType !== render.AnnotationType.SUMMARY && + annotation.annotationType !== render.AnnotationType.CONSTANT + ) { + selection.on('contextmenu', contextmenu.getMenu(sceneElement, getContextMenu(annotation.node, sceneElement)), d); + } +} +/** + * Adjust annotation's position. + * + * @param aGroup selection of a 'g.annotation' element. + * @param d Host node data. + * @param a annotation node data. + * @param sceneElement polymer element. + */ +function update(aGroup, d: render.RenderNodeInfo, a: render.Annotation, sceneElement) { + let cx = layout.computeCXPositionOfNodeShape(d); + // Annotations that point to embedded nodes (constants,summary) + // don't have a render information attached so we don't stylize these. + // Also we don't stylize ellipsis annotations (the string '... and X more'). + if (a.renderNodeInfo && a.annotationType !== render.AnnotationType.ELLIPSIS) { + stylize(aGroup, a.renderNodeInfo, sceneElement, Class.Annotation.NODE); + } + if (a.annotationType === render.AnnotationType.SUMMARY) { + // Update the width of the annotation to give space for the image. + a.width += 10; + } + // label position + aGroup + .select('text.' + Class.Annotation.LABEL) + .transition() + .attr('x', cx + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset)) + .attr('y', d.y + a.dy); + // Some annotations (such as summary) are represented using a 12x12 image tag. + // Purposely omitted units (e.g. pixels) since the images are vector graphics. + // If there is an image, we adjust the location of the image to be vertically + // centered with the node and horizontally centered between the arrow and the + // text label. + aGroup + .select('use.summary') + .transition() + .attr('x', cx + a.dx - 3) + .attr('y', d.y + a.dy - 6); + // Node position (only one of the shape selection will be non-empty.) + positionEllipse(aGroup.select('.' + Class.Annotation.NODE + ' ellipse'), cx + a.dx, d.y + a.dy, a.width, a.height); + positionRect(aGroup.select('.' + Class.Annotation.NODE + ' rect'), cx + a.dx, d.y + a.dy, a.width, a.height); + positionRect(aGroup.select('.' + Class.Annotation.NODE + ' use'), cx + a.dx, d.y + a.dy, a.width, a.height); + // Edge position + aGroup + .select('path.' + Class.Annotation.EDGE) + .transition() + .attr('d', (a) => { + // map relative position to absolute position + let points = a.points.map((p) => { + return { x: p.dx + cx, y: p.dy + d.y }; + }); + return edge.interpolate(points); + }); +} + +/** + * Scene. + */ + +/** + * Select or create a sceneGroup and build/update its nodes and edges. + * + * Structure Pattern: + * + * + * + * + * ... stuff from tf.graph.scene.edges.build ... + * + * + * ... stuff from tf.graph.scene.nodes.build ... + * + * + * + * + * ... stuff from tf.graph.scene.nodes.build ... + * + * + * + * + * ... stuff from tf.graph.scene.nodes.build ... + * + * + * + * + * @param container D3 selection of the parent. + * @param renderNode render node of a metanode or series node. + * @param sceneElement polymer element. + * @param sceneClass class attribute of the scene (default='scene'). + */ +export function buildGroupForScene( + container, + renderNode: render.RenderGroupNodeInfo, + sceneElement: TfGraphScene, + sceneClass?: string, +): d3.Selection { + sceneClass = sceneClass || Class.Scene.GROUP; + let isNewSceneGroup = selectChild(container, 'g', sceneClass).empty(); + let sceneGroup = selectOrCreateChild(container, 'g', sceneClass); + // core + let coreGroup = selectOrCreateChild(sceneGroup, 'g', Class.Scene.CORE); + let coreNodes = _.reduce( + renderNode.coreGraph.nodes(), + (nodes, name) => { + let node = renderNode.coreGraph.node(name); + if (!node.excluded) { + nodes.push(node); + } + return nodes; + }, + Array(), + ); + if (renderNode.node.type === NodeType.SERIES) { + // For series, we want the first item on top, so reverse the array so + // the first item in the series becomes last item in the top, and thus + // is rendered on the top. + coreNodes.reverse(); + } + // Create the layer of edges for this scene (paths). + edge.buildGroup(coreGroup, renderNode.coreGraph, sceneElement); + // Create the layer of nodes for this scene (ellipses, rects etc). + buildGroup(coreGroup, coreNodes, sceneElement); + // In-extract + if (renderNode.isolatedInExtract.length > 0) { + let inExtractGroup = selectOrCreateChild(sceneGroup, 'g', Class.Scene.INEXTRACT); + buildGroup(inExtractGroup, renderNode.isolatedInExtract, sceneElement); + } else { + selectChild(sceneGroup, 'g', Class.Scene.INEXTRACT).remove(); + } + // Out-extract + if (renderNode.isolatedOutExtract.length > 0) { + let outExtractGroup = selectOrCreateChild(sceneGroup, 'g', Class.Scene.OUTEXTRACT); + buildGroup(outExtractGroup, renderNode.isolatedOutExtract, sceneElement); + } else { + selectChild(sceneGroup, 'g', Class.Scene.OUTEXTRACT).remove(); + } + // Library functions + if (renderNode.libraryFunctionsExtract.length > 0) { + let outExtractGroup = selectOrCreateChild(sceneGroup, 'g', Class.Scene.FUNCTION_LIBRARY); + buildGroup(outExtractGroup, renderNode.libraryFunctionsExtract, sceneElement); + } else { + selectChild(sceneGroup, 'g', Class.Scene.FUNCTION_LIBRARY).remove(); + } + tf_graph_scene.position(sceneGroup, renderNode); + // Fade in the scene group if it didn't already exist. + if (isNewSceneGroup) { + sceneGroup.attr('opacity', 0).transition().attr('opacity', 1); + } + return sceneGroup; +} +export function getColors(colors) { + colorStorage = colors; +} + +export function getMatched(matched) { + matchStorage = matched; +} + +export function getUnMatched(unmatched) { + unMatchStorage = unmatched; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/op.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/op.ts new file mode 100644 index 0000000000000000000000000000000000000000..e8412cd93ec8874f4bfd0b19ca88e78e0563508e --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/op.ts @@ -0,0 +1,588 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as _ from 'lodash'; +import {FUNCTION_LIBRARY_NODE_PREFIX, OpNode, SlimGraph} from './graph'; + +export interface CompatibilityProvider { + opValid: (opNode: OpNode) => boolean; +} +export class TpuCompatibilityProvider implements CompatibilityProvider { + /** + * Allowed list of current Tensorflow ops valid on the TPU. + * Note that some data types may be unsupported. + */ + static readonly WHITELIST = [ + 'Abs', + 'Acos', + 'Acosh', + 'Add', + 'AddN', + 'AddV2', + 'AdjustContrastv2', + 'AdjustHue', + 'AdjustSaturation', + 'All', + 'AllToAll', + 'Angle', + 'Any', + 'ApproximateEqual', + 'ArgMax', + 'ArgMin', + 'Asin', + 'Asinh', + 'Assert', + 'AssignAddVariableOp', + 'AssignSubVariableOp', + 'AssignVariableOp', + 'Atan', + 'Atan2', + 'Atanh', + 'AvgPool', + 'AvgPool3D', + 'AvgPool3DGrad', + 'AvgPoolGrad', + 'BatchMatMul', + 'BatchMatMulV2', + 'BatchToSpace', + 'BatchToSpaceND', + 'BesselI0e', + 'BesselI1e', + 'Betainc', + 'BiasAdd', + 'BiasAddGrad', + 'BiasAddV1', + 'Bitcast', + 'BitwiseAnd', + 'BitwiseOr', + 'BitwiseXor', + 'BroadcastArgs', + 'BroadcastGradientArgs', + 'BroadcastTo', + 'Bucketize', + 'Case', + 'Cast', + 'Ceil', + 'CheckNumerics', + 'Cholesky', + 'ClipByValue', + 'CollectivePermute', + 'CollectiveReduceV2', + 'Complex', + 'ComplexAbs', + 'Concat', + 'ConcatOffset', + 'ConcatV2', + 'Conj', + 'ConjugateTranspose', + 'Const', + 'ControlTrigger', + 'Conv2D', + 'Conv2DBackpropFilter', + 'Conv2DBackpropInput', + 'Conv3D', + 'Conv3DBackpropFilterV2', + 'Conv3DBackpropInputV2', + 'Cos', + 'Cosh', + 'Cross', + 'CrossReplicaSum', + 'Cumprod', + 'Cumsum', + 'DataFormatDimMap', + 'DataFormatVecPermute', + 'DepthToSpace', + 'DepthwiseConv2dNative', + 'DepthwiseConv2dNativeBackpropFilter', + 'DepthwiseConv2dNativeBackpropInput', + 'Dequantize', + 'DeviceIndex', + 'Diag', + 'DiagPart', + 'Digamma', + 'Div', + 'DivNoNan', + 'DynamicStitch', + 'Einsum', + 'Elu', + 'EluGrad', + 'Empty', + 'EmptyTensorList', + 'EnsureShape', + 'Equal', + 'Erf', + 'Erfc', + 'Erfinv', + 'Exp', + 'ExpandDims', + 'Expm1', + 'ExtractImagePatches', + 'FFT', + 'FFT2D', + 'FFT3D', + 'FakeParam', + 'FakeQuantWithMinMaxArgs', + 'FakeQuantWithMinMaxArgsGradient', + 'FakeQuantWithMinMaxVars', + 'FakeQuantWithMinMaxVarsGradient', + 'Fill', + 'Floor', + 'FloorDiv', + 'FloorMod', + 'FusedBatchNorm', + 'FusedBatchNormGrad', + 'FusedBatchNormGradV2', + 'FusedBatchNormGradV3', + 'FusedBatchNormV2', + 'FusedBatchNormV3', + 'Gather', + 'GatherNd', + 'GatherV2', + 'GetItem', + 'Greater', + 'GreaterEqual', + 'HSVToRGB', + 'IFFT', + 'IFFT2D', + 'IFFT3D', + 'IRFFT', + 'IRFFT2D', + 'IRFFT3D', + 'Identity', + 'IdentityN', + 'If', + 'Igamma', + 'IgammaGradA', + 'Igammac', + 'Imag', + 'InTopKV2', + 'InfeedDequeue', + 'InfeedDequeueTuple', + 'InplaceAdd', + 'InplaceUpdate', + 'Inv', + 'Invert', + 'InvertPermutation', + 'IsFinite', + 'IsInf', + 'IsNan', + 'KthOrderStatistic', + 'L2Loss', + 'LRN', + 'LRNGrad', + 'LeakyRelu', + 'LeakyReluGrad', + 'LeftShift', + 'Less', + 'LessEqual', + 'Lgamma', + 'LinSpace', + 'ListDiff', + 'Log', + 'Log1p', + 'LogSoftmax', + 'LogicalAnd', + 'LogicalNot', + 'LogicalOr', + 'LowerBound', + 'MakeUnique', + 'MatMul', + 'MatrixBandPart', + 'MatrixDiag', + 'MatrixDiagPart', + 'MatrixDiagPartV2', + 'MatrixDiagPartV3', + 'MatrixDiagV2', + 'MatrixDiagV3', + 'MatrixInverse', + 'MatrixSetDiag', + 'MatrixSetDiagV2', + 'MatrixSetDiagV3', + 'MatrixSolve', + 'MatrixTriangularSolve', + 'Max', + 'MaxPool', + 'MaxPool3D', + 'MaxPool3DGrad', + 'MaxPool3DGradGrad', + 'MaxPoolGrad', + 'MaxPoolGradGrad', + 'MaxPoolGradGradV2', + 'MaxPoolGradV2', + 'MaxPoolV2', + 'Maximum', + 'Mean', + 'Min', + 'Minimum', + 'MirrorPad', + 'MirrorPadGrad', + 'Mod', + 'Mul', + 'MulNoNan', + 'Multinomial', + 'Ndtri', + 'Neg', + 'NextAfter', + 'NoOp', + 'NonMaxSuppressionV4', + 'NotEqual', + 'OneHot', + 'OnesLike', + 'OutfeedEnqueue', + 'OutfeedEnqueueTuple', + 'Pack', + 'Pad', + 'PadV2', + 'ParallelDynamicStitch', + 'ParameterizedTruncatedNormal', + 'PartitionedCall', + 'PlaceholderWithDefault', + 'Polygamma', + 'PopulationCount', + 'Pow', + 'PreventGradient', + 'Prod', + 'Qr', + 'QuantizeAndDequantizeV2', + 'QuantizeAndDequantizeV3', + 'RFFT', + 'RFFT2D', + 'RFFT3D', + 'RGBToHSV', + 'RandomGammaGrad', + 'RandomShuffle', + 'RandomStandardNormal', + 'RandomUniform', + 'RandomUniformInt', + 'Range', + 'Rank', + 'ReadVariableOp', + 'Real', + 'RealDiv', + 'Reciprocal', + 'ReciprocalGrad', + 'Relu', + 'Relu6', + 'Relu6Grad', + 'ReluGrad', + 'Reshape', + 'ResizeBilinear', + 'ResizeBilinearGrad', + 'ResizeNearestNeighbor', + 'ResizeNearestNeighborGrad', + 'ResourceApplyAdaMax', + 'ResourceApplyAdadelta', + 'ResourceApplyAdagrad', + 'ResourceApplyAdagradDA', + 'ResourceApplyAdagradV2', + 'ResourceApplyAdam', + 'ResourceApplyAddSign', + 'ResourceApplyCenteredRMSProp', + 'ResourceApplyFtrl', + 'ResourceApplyFtrlV2', + 'ResourceApplyGradientDescent', + 'ResourceApplyKerasMomentum', + 'ResourceApplyMomentum', + 'ResourceApplyPowerSign', + 'ResourceApplyProximalAdagrad', + 'ResourceApplyProximalGradientDescent', + 'ResourceApplyRMSProp', + 'ResourceGather', + 'ResourceScatterAdd', + 'ResourceScatterDiv', + 'ResourceScatterMax', + 'ResourceScatterMin', + 'ResourceScatterMul', + 'ResourceScatterNdAdd', + 'ResourceScatterNdSub', + 'ResourceScatterNdUpdate', + 'ResourceScatterSub', + 'ResourceScatterUpdate', + 'ResourceStridedSliceAssign', + 'Reverse', + 'ReverseSequence', + 'ReverseV2', + 'RightShift', + 'Rint', + 'RngReadAndSkip', + 'RngSkip', + 'Roll', + 'Round', + 'Rsqrt', + 'RsqrtGrad', + 'ScatterNd', + 'Select', + 'SelectV2', + 'SelfAdjointEigV2', + 'Selu', + 'SeluGrad', + 'Shape', + 'ShapeN', + 'Sigmoid', + 'SigmoidGrad', + 'Sign', + 'Sin', + 'Sinh', + 'Size', + 'Slice', + 'Snapshot', + 'Softmax', + 'SoftmaxCrossEntropyWithLogits', + 'Softplus', + 'SoftplusGrad', + 'Softsign', + 'SoftsignGrad', + 'SpaceToBatch', + 'SpaceToBatchND', + 'SpaceToDepth', + 'SparseMatMul', + 'SparseSoftmaxCrossEntropyWithLogits', + 'SparseToDense', + 'Split', + 'SplitV', + 'Sqrt', + 'SqrtGrad', + 'Square', + 'SquaredDifference', + 'Squeeze', + 'StackCloseV2', + 'StackPopV2', + 'StackPushV2', + 'StackV2', + 'StatefulPartitionedCall', + 'StatefulStandardNormalV2', + 'StatefulTruncatedNormal', + 'StatefulUniform', + 'StatefulUniformFullInt', + 'StatefulUniformInt', + 'StatelessCase', + 'StatelessIf', + 'StatelessMultinomial', + 'StatelessRandomGetAlg', + 'StatelessRandomGetKeyCounter', + 'StatelessRandomGetKeyCounterAlg', + 'StatelessRandomNormal', + 'StatelessRandomNormalV2', + 'StatelessRandomUniform', + 'StatelessRandomUniformFullInt', + 'StatelessRandomUniformFullIntV2', + 'StatelessRandomUniformInt', + 'StatelessRandomUniformIntV2', + 'StatelessRandomUniformV2', + 'StatelessTruncatedNormal', + 'StatelessTruncatedNormalV2', + 'StatelessWhile', + 'StopGradient', + 'StridedSlice', + 'StridedSliceGrad', + 'Sub', + 'Sum', + 'Svd', + 'SymbolicGradient', + 'TPUEmbeddingActivations', + 'Tan', + 'Tanh', + 'TanhGrad', + 'TensorArrayCloseV3', + 'TensorArrayConcatV3', + 'TensorArrayGatherV3', + 'TensorArrayGradV3', + 'TensorArrayReadV3', + 'TensorArrayScatterV3', + 'TensorArraySizeV3', + 'TensorArraySplitV3', + 'TensorArrayV3', + 'TensorArrayWriteV3', + 'TensorListConcatV2', + 'TensorListElementShape', + 'TensorListFromTensor', + 'TensorListGather', + 'TensorListGetItem', + 'TensorListLength', + 'TensorListPopBack', + 'TensorListPushBack', + 'TensorListReserve', + 'TensorListSetItem', + 'TensorListSplit', + 'TensorListStack', + 'TensorScatterAdd', + 'TensorScatterMax', + 'TensorScatterMin', + 'TensorScatterSub', + 'TensorScatterUpdate', + 'TensorStridedSliceUpdate', + 'Tile', + 'TopKUnique', + 'TopKV2', + 'TopKWithUnique', + 'Transpose', + 'TridiagonalSolve', + 'TruncateDiv', + 'TruncateMod', + 'TruncatedNormal', + 'Unique', + 'Unpack', + 'UnsortedSegmentMax', + 'UnsortedSegmentMin', + 'UnsortedSegmentProd', + 'UnsortedSegmentSum', + 'UpperBound', + 'VarIsInitializedOp', + 'VariableShape', + 'Where', + 'While', + 'Xdivy', + 'XlaBroadcastHelper', + 'XlaConv', + 'XlaConvV2', + 'XlaDequantize', + 'XlaDot', + 'XlaDotV2', + 'XlaDynamicSlice', + 'XlaDynamicUpdateSlice', + 'XlaEinsum', + 'XlaGather', + 'XlaHostCompute', + 'XlaIf', + 'XlaKeyValueSort', + 'XlaPad', + 'XlaRecv', + 'XlaRecvFromHost', + 'XlaReduce', + 'XlaReduceWindow', + 'XlaReplicaId', + 'XlaScatter', + 'XlaSelectAndScatter', + 'XlaSelfAdjointEig', + 'XlaSend', + 'XlaSendToHost', + 'XlaSetBound', + 'XlaSetDynamicDimensionSize', + 'XlaSharding', + 'XlaSort', + 'XlaSpmdFullToShardShape', + 'XlaSpmdShardToFullShape', + 'XlaSvd', + 'XlaVariadicReduce', + 'XlaVariadicSort', + 'XlaWhile', + 'Xlog1py', + 'Xlogy', + 'ZerosLike', + 'Zeta', + + // Ops below are manually whitelisted and should not be evaluated for + // compatibility for various reasons. + + // Control flow ops. + 'Enter', + 'Exit', + 'LoopCond', + 'Merge', + 'NextIteration', + 'Switch', + // Ops below are inserted by the compiler. + '_Arg', + '_ArrayToList', + '_FusedBatchNormEx', + '_ListToArray', + '_ParallelConcatUpdate', + '_RecvTPUEmbeddingActivations', + '_RecvTPUEmbeddingDeduplicationData', + '_Retval', + '_SendTPUEmbeddingGradients', + '_TPUCompile', + '_TPUExecute', + '_UnaryOpsComposition', + // Distributed TPU ops. + 'TPUCompilationResult', + 'TPUReplicatedInput', + 'TPUReplicatedOutput', + 'TPUReplicateMetadata', + // Checkpointing ops. + 'MergeV2Checkpoints', + 'RestoreV2', + 'SaveV2', + // Miscellaneous CPU ops. + 'Abort', + 'Assert', + 'Assign', + 'Placeholder', + 'PlaceholderV2', + 'ShardedFilename', + 'StringJoin', + 'Variable', + 'VariableV2', + 'VarHandleOp', + // Summary ops. + 'AudioSummary', + 'AudioSummaryV2', + 'DebugNumericSummary', + 'HistogramSummary', + 'ImageSummary', + 'MergeSummary', + 'ScalarSummary', + 'StatsAggregatorSummary', + ]; + + /** + * Returns true if the node's inferred device is not the TPU. + * Note that this is only a best-effort check. + */ + private isNotTpuOp(opDevice: string): boolean { + if (opDevice.toLowerCase().search('cpu:') != -1) { + return true; + } + if (opDevice.toLowerCase().search('gpu:') != -1) { + return true; + } + return opDevice.toLowerCase().search('tpu') == -1; + } + opValid(opNode: OpNode): boolean { + // Function library nodes are generally for internal use. + if (opNode.name.search(FUNCTION_LIBRARY_NODE_PREFIX) == 0) { + return true; + } + // Nodes that lack op types should be ignored. + if (!opNode.op) { + return true; + } + // If assigned a device that is not TPU-related assume op is valid. + if (opNode.device && this.isNotTpuOp(opNode.device)) { + return true; + } + // If assigned to the TPU_SYSTEM device, assume op is valid. + if (opNode.device && opNode.device.search('TPU_SYSTEM') != -1) { + return true; + } + return _.includes(TpuCompatibilityProvider.WHITELIST, opNode.op); + } +} +export function checkOpsForCompatibility( + graph: SlimGraph, + provider: CompatibilityProvider +) { + if (provider === null) { + throw new Error('Compatibility provider required, but got: ' + provider); + } + _.each(graph.nodes, (node) => { + node.compatible = provider.opValid(node); + _.each(node.inEmbeddings, (node) => { + node.compatible = provider.opValid(node); + }); + _.each(node.outEmbeddings, (node) => { + node.compatible = provider.opValid(node); + }); + }); +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/parser.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/parser.ts new file mode 100644 index 0000000000000000000000000000000000000000..7ee728395741e2d0eed422cf301ba795b8e3832f --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/parser.ts @@ -0,0 +1,348 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as tb_debug from '../tb_debug'; +import { ProgressTracker } from './common'; +import * as tf_graph_proto from './proto'; +import * as tf_graph_util from './util'; + +function parseValue(value: string): string | number | boolean { + if (value === 'true') { + return true; + } + if (value === 'false') { + return false; + } + let firstChar = value[0]; + if (firstChar === '"') { + return value.substring(1, value.length - 1); + } + let num = parseFloat(value); + return isNaN(num) ? value : num; +} +/** + * Fetches a text file and returns a promise of the result. + */ +export function fetchPbTxt(filepath: string): Promise { + return new Promise((resolve, reject) => { + fetch(filepath).then((res) => { + // Fetch does not reject for 400+. + if (res.ok) { + res.arrayBuffer().then(resolve, reject); + } else { + res.text().then(reject, reject); + } + }); + }); +} +/** + * Fetches the metadata file, parses it and returns a promise of the result. + */ +export function fetchAndParseMetadata(path: string, tracker: ProgressTracker) { + return tf_graph_util + .runTask( + 'Reading metadata pbtxt', + 40, + () => { + if (path == null) { + return Promise.resolve(null); + } + return fetchPbTxt(path); + }, + tracker, + tb_debug.GraphDebugEventId.FETCH_METADATA_PBTXT_BYTES, + ) + .then((arrayBuffer: ArrayBuffer | null) => { + return tf_graph_util.runAsyncPromiseTask( + 'Parsing metadata.pbtxt', + 60, + () => { + return arrayBuffer != null ? parseStatsPbTxt(arrayBuffer) : Promise.resolve(null); + }, + tracker, + tb_debug.GraphDebugEventId.PARSE_METADATA_PBTXT_INTO_OBJECT, + ); + }); +} +/** + * Fetches the graph file, parses it and returns a promise of the result. The + * result will be undefined if the graph is empty. + */ +export function fetchAndParseGraphData( + path: string, + pbTxtFile: Blob, + tracker: ProgressTracker, +): Promise { + return tf_graph_util + .runAsyncPromiseTask( + 'Reading graph pbtxt', + 40, + async () => { + const start = Date.now(); + if (pbTxtFile) { + const result = await new Promise(function (resolve, reject) { + let fileReader = new FileReader(); + fileReader.onload = () => resolve(fileReader.result as ArrayBuffer); + fileReader.onerror = () => reject(fileReader.error); + fileReader.readAsArrayBuffer(pbTxtFile); + }); + tf_graph_util.notifyDebugEvent({ + timingId: tb_debug.GraphDebugEventId.FETCH_PBTXT_BYTES_FROM_FILESYSTEM, + eventValue: Date.now() - start, + }); + return result; + } + + const result = await fetchPbTxt(path); + tf_graph_util.notifyDebugEvent({ + timingId: tb_debug.GraphDebugEventId.FETCH_PBTXT_BYTES_FROM_SERVER, + eventValue: Date.now() - start, + }); + return result; + }, + tracker, + tb_debug.GraphDebugEventId.FETCH_PBTXT_BYTES, + ) + .then((arrayBuffer: ArrayBuffer) => { + return tf_graph_util.runAsyncPromiseTask( + 'Parsing graph.pbtxt', + 60, + () => { + return parseGraphPbTxt(arrayBuffer); + }, + tracker, + tb_debug.GraphDebugEventId.PARSE_PBTXT_INTO_OBJECT, + ); + }); +} +/** + * Parse a file object in a streaming fashion line by line (or custom delim). + * Can handle very large files. + * @param input The file object as an array buffer. + * @param callback The callback called on each line + * @param chunkSize The size of each read chunk. (optional) + * @param delim The delimiter used to split a line. (optional) + * @returns Promise that resolves with true when it is finished. + */ +export function streamParse( + arrayBuffer: ArrayBuffer, + callback: (string) => void, + chunkSize: number = 1000000, + delim: string = '\n', +): Promise { + return new Promise(function (resolve, reject) { + function readChunk(oldData: string, newData: string, offset: number) { + const doneReading = offset >= arrayBuffer.byteLength; + const parts = newData.split(delim); + parts[0] = oldData + parts[0]; + // The last part may be part of a longer string that got cut off + // due to the chunking. + const remainder = doneReading ? '' : parts.pop(); + for (let part of parts) { + try { + callback(part); + } catch (e) { + reject(e); + return; + } + } + if (doneReading) { + resolve(true); + return; + } + const nextChunk = new Blob([arrayBuffer.slice(offset, offset + chunkSize)]); + const file = new FileReader(); + file.onload = function (e: any) { + readChunk(remainder!, e.target.result, offset + chunkSize); + }; + file.readAsText(nextChunk); + } + readChunk('', '', 0); + }); +} +/** + * Since proto-txt doesn't explicitly say whether an attribute is repeated + * (an array) or not, we keep a hard-coded list of attributes that are known + * to be repeated. This list is used in parsing time to convert repeated + * attributes into arrays even when the attribute only shows up once in the + * object. + * Repeated fields have to be in sync with graph.proto and all of its + * dependencies. + * See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto + */ +const GRAPH_REPEATED_FIELDS: { + [attrPath: string]: boolean; +} = { + 'library.function': true, + 'library.function.node_def': true, + 'library.function.node_def.input': true, + 'library.function.node_def.attr': true, + 'library.function.node_def.attr.value.list.b': true, + 'library.function.node_def.attr.value.list.f': true, + 'library.function.node_def.attr.value.list.func': true, + 'library.function.node_def.attr.value.list.i': true, + 'library.function.node_def.attr.value.list.s': true, + 'library.function.node_def.attr.value.list.shape': true, + 'library.function.node_def.attr.value.list.shape.dim': true, + 'library.function.node_def.attr.value.list.tensor': true, + 'library.function.node_def.attr.value.list.type': true, + 'library.function.node_def.attr.value.shape.dim': true, + 'library.function.node_def.attr.value.tensor.string_val': true, + 'library.function.node_def.attr.value.tensor.tensor_shape.dim': true, + 'library.function.signature.input_arg': true, + 'library.function.signature.output_arg': true, + 'library.versions': true, + node: true, + 'node.input': true, + 'node.attr.value.list.b': true, + 'node.attr.value.list.f': true, + 'node.attr.value.list.func': true, + 'node.attr.value.list.i': true, + 'node.attr.value.list.s': true, + 'node.attr.value.list.shape': true, + 'node.attr.value.list.shape.dim': true, + 'node.attr.value.list.tensor': true, + 'node.attr.value.list.type': true, + 'node.attr.value.shape.dim': true, + 'node.attr.value.tensor.string_val': true, + 'node.attr.value.tensor.tensor_shape.dim': true, +}; +const METADATA_REPEATED_FIELDS: { + [attrPath: string]: boolean; +} = { + 'step_stats.dev_stats': true, + 'step_stats.dev_stats.node_stats': true, + 'step_stats.dev_stats.node_stats.output': true, + 'step_stats.dev_stats.node_stats.memory': true, + 'step_stats.dev_stats.node_stats.output.tensor_description.shape.dim': true, +}; +/** + * Parses an ArrayBuffer of a proto txt file into a raw Graph object. + */ +export function parseGraphPbTxt(input: ArrayBuffer): Promise { + return parsePbtxtFile(input, GRAPH_REPEATED_FIELDS); +} +/** + * Parses an ArrayBuffer of a proto txt file into a StepStats object. + */ +export function parseStatsPbTxt(input: ArrayBuffer): Promise { + return parsePbtxtFile(input, METADATA_REPEATED_FIELDS).then((obj) => obj['step_stats']); +} +/** + * Parses a ArrayBuffer of a proto txt file into javascript object. + * + * @param input The ArrayBuffer or file object implementing slice. + * @param repeatedFields Map (Set) of all the repeated fields, since you can't + * tell directly from the pbtxt if a field is repeated or not. + * @returns The parsed object. + */ +function parsePbtxtFile( + input: ArrayBuffer, + repeatedFields: { + [attrPath: string]: boolean; + }, +): Promise { + let output: { + [name: string]: any; + } = {}; + let stack: {}[] = []; + let path: string[] = []; + let current: { + [name: string]: any; + } = output; + function splitNameAndValueInAttribute(line: string) { + let colonIndex = line.indexOf(':'); + let name = line.substring(0, colonIndex).trim(); + let value: any = parseValue(line.substring(colonIndex + 2).trim()); + if (name === 'input_data' || name === 'output_data') { + value = JSON.parse( + (value as string) + .replace(/'{/g, '{') + .replace(/}'/g, '}') + .replace(/'None'/g, '{"type": "None"}') + .replace(/'/g, '"'), + ) as object; + } else if (name === 'matched_node_link') { + value = JSON.parse((value as string).replace(/'/g, '"')) as string[]; + } else if (name === 'subnodes') { + value = JSON.parse((value as string).replace(/'/g, '"')) as string[]; + } else if (name === 'suggestions') { + value = JSON.parse((value as string).replace(/'{/g, '{').replace(/}'/g, '}').replace(/'/g, '"')) as object; + } + if (name === 'attr') { + const valueObj = JSON.parse((value as string).replace(/'/g, '"')) as object; + value = Object.keys(valueObj).map((key) => { + return { + key, + value: valueObj[key], + }; + }); + } + return { + name: name, + value: value, + }; + } + /** + * Adds a value, given the attribute name and the host object. If the + * attribute already exists, but is not an array, it will convert it to an + * array of values. + * + * @param obj The host object that holds the attribute. + * @param name The attribute name (key). + * @param value The attribute value. + * @param path A path that identifies the attribute. Used to check if + * an attribute is an array or not. + */ + function addAttribute(obj: Object, name: string, value: Object | string | number | boolean, path: string[]): void { + // We treat 'node' specially since it is done so often. + let existingValue = obj[name]; + if (existingValue == null) { + obj[name] = path.join('.') in repeatedFields ? [value] : value; + } else if (Array.isArray(existingValue)) { + existingValue.push(value); + } else { + obj[name] = [existingValue, value]; + } + } + // Run through the file a line at a time. + return streamParse(input, function (line: string) { + line = line.trim(); + if (!line) { + return; + } + switch (line[line.length - 1]) { + case '{': // create new object + let name = line.substring(0, line.length - 2).trim(); + let newValue: { + [name: string]: any; + } = {}; + stack.push(current); + path.push(name); + addAttribute(current, name, newValue, path); + current = newValue; + break; + case '}': + current = stack.pop()!; + path.pop(); + break; + default: + let x = splitNameAndValueInAttribute(line); + addAttribute(current, x.name, x.value, path.concat(x.name)); + break; + } + }).then(function () { + return output; + }); +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/proto.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/proto.ts new file mode 100644 index 0000000000000000000000000000000000000000..317d2de528ccb7f12445939a84478d6db388aa71 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/proto.ts @@ -0,0 +1,244 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/** + * @fileoverview Interfaces that parallel proto definitions in + * third_party/tensorflow/core/framework/... + * graph.proto + * step_stats.proto + * These should stay in sync. + * + * When adding a repeated field to this file, make sure to update the + * GRAPH_REPEATED_FIELDS and METADATA_REPEATED_FIELDS lists within parser.ts. + * Otherwise, the parser has no way of differentiating between a field with a + * certain value and a repeated field that has only 1 occurence, resulting in + * subtle bugs. + */ + +export enum NodeOpType { + MODULE = 0, + DEFAULT = 1, + API_LIST = 9, +} +/** Name of the node */ +export interface NodeDef { + name: string; + /** List of nodes that are inputs for this node. */ + input: string[]; + /** The name of the device where the computation will run. */ + device: string; + /** The name of the operation associated with this node. */ + op: string; + /** The op type of the node. */ + node_type: NodeOpType; + /** The array of inputs data in JSON string format. */ + input_data: { + [key: string]: any; + }; + /** The array of outputs data in JSON string format. */ + output_data: { + [key: string]: any; + }; + stack_info: []; + matched_node_link: []; + suggestions: { + [key: string]: string; + }; + /** The array consist of the path of linked node in graph comparison. */ + subnodes?: string[]; + isLeaf: boolean; + /** List of attributes that describe/modify the operation. */ + attr: { + key: string; + value: Object; + }[]; +} +export interface EdgeInfo { + input: string; + output: string; + shape: string; + attr: { + key: string; + value: Object; + }[]; +} +/** + * Describes a version of TensorFlow. + */ +export interface VersionDef { + // The version of the code that produced this data. + producer: number; + // Any consumer below this version is not allowed to consume this data. + min_consumer: number; + // Specific consumer versions which are disallowed (e.g. due to bugs). + bad_consumers: number[]; +} +/** + * Specifies an argument. An argument is either an input or an output of a + * function. There are thus 2 types of arguments: input_args and output_args. + * Nodes outside a function call connect to arguments. The graph explorer + * creates nodes for all arguments within a function. + */ +export interface ArgDef { + name: string; + type: string; +} +/** + * Describes the signature of a function - its name, inputs, and outputs. + */ +export interface OpDef { + name: string; + input_arg: ArgDef[]; + output_arg: ArgDef[]; +} +/** + * Describes a single function within the library. + */ +export interface FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + signature: OpDef; + // A list of nodes in the function. + node_def: NodeDef[]; +} +/** + * Describes a library of functions that may be composed throughout the graph. + */ +export interface FunctionDefLibraryDef { + // A list of functions. + function: FunctionDef[]; +} +/** + * TensorFlow graph definition as defined in the graph.proto file. + */ +export interface GraphDef { + // A list of nodes in the graph. + node: NodeDef[]; + // Compatibility versions of the graph. + versions: VersionDef[]; + // Contains a library of functions that may composed through the graph. + library: FunctionDefLibraryDef; + // The information of the list of edges. + edge: EdgeInfo[]; +} +/** + * Generic graph as defined in the graph_explorer.proto file. + */ +export interface GenericGraph { + /** List of nodes in the graph */ + node: GenericNode[]; + /** List of nodes in the graph */ + edge: GenericEdge[]; + /** List of attributes that describe/modify the operation. */ + attr: Array<{ + [key: string]: any; + }>; +} +/** + * GenericEdge corresponds to the Edge message in graph_explorer.proto. + */ +export interface GenericEdge { + /** Name of the source node. */ + source: string; + /** Name of the target node. */ + target: string; + /** Attributes of the edge. */ + edge_attr: Array<{ + [key: string]: any; + }>; +} +/** + * GenericNode corresponds to the Node message in graph_explorer.proto. + */ +export interface GenericNode { + /** Name of the node */ + name: string; + /** Attributes of a leaf node or leaf nodes within a metanode. */ + node_attr: Array<{ + [key: string]: any; + }>; + /** Attributes of a metanode. */ + metanode_attr: Array<{ + [key: string]: any; + }>; +} +export interface DevStat { + device: string; + node_stats: NodeExecStats[]; +} +/** + * TensorFlow stats file definition as defined in the stats proto file. + */ +export interface StepStats { + dev_stats: DevStat[]; +} +/** + * TensorFlow stats for a node as defined in the step_stats proto file. + */ +export interface NodeExecStats { + node_name: string; + // The next 4 properties are currently stored as string in json + // and must be parsed. + all_start_micros: number; + op_start_rel_micros: number; + op_end_rel_micros: number; + all_end_rel_micros: number; + memory: { + allocator_name: string; + total_bytes: number; // Stored as string in json and should be parsed. + peak_bytes: number; // Stored as string in json and should be parsed. + }[]; + /** Output sizes recorded for a single execution of a graph node */ + output: NodeOutput[]; + timeline_label: string; + scheduled_micros: string; + thread_id: string; +} +/** + * Description for the output tensor(s) of an operation in the graph as + * defined in the step_stats.proto file. + */ +export interface NodeOutput { + slot: number; // Stored as string in json and should be parsed. + tensor_description: { + /** Data type of tensor elements */ + dtype: string; + /** Shape of the tensor */ + shape: { + /** + * Dimensions of the tensor, such as [{name: 'input', size: 30}, + * {name: 'output', size: 40}] for a 30 x 40 2D tensor. The names + * are optional. The order of entries in 'dim' matters: It indicates + * the layout of the values in the tensor in-memory representation. + */ + dim: { + /** Size of the tensor in that dimension */ + size: number; // Stored as string in json and should be parsed. + /** Optional name of the tensor dimension */ + name?: string; + }[]; + }; + /** Information about the size and allocator used for the data */ + allocation_description: { + // The next 2 properties are stored as string in json and + // should be parsed. + /** Total number of bytes requested */ + requested_bytes: number; + /** Total number of bytes allocated, if known */ + allocated_bytes?: number; + /** Name of the allocator used */ + allocator_name: string; + }; + }; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/render.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/render.ts new file mode 100644 index 0000000000000000000000000000000000000000..f0b9da215a395bda67ca985db1ac3e3923e2cfc0 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/render.ts @@ -0,0 +1,2101 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/** + * Package for the Render Hierarchy for TensorFlow graph. + */ +import * as d3 from 'd3'; +import { graphlib } from 'dagre'; +import * as _ from 'lodash'; +import * as tf_graph_common from './common'; +import * as tf_graph from './graph'; +import { + BridgeNode, + createGraph, + EllipsisNode, + getHierarchicalPath, + GraphType, + GroupNode, + InclusionType, + Metaedge, + MetaedgeImpl, + Metanode, + Node, + NodeType, + OpNode, + OpNodeImpl, +} from './graph'; +import { Hierarchy } from './hierarchy'; +import { NodeOpType } from './proto'; +import * as tf_graph_util from './util'; + +export type EdgeData = { + v: string; + w: string; + id: number; + label: RenderMetaedgeInfo; +}; + +export type Point = { + x: number; + y: number; +}; +/** + * Color parameters for op nodes. + */ +export let OpNodeColors = { + DEFAULT_FILL: '#ffffff', + DEFAULT_STROKE: '#b2b2b2', + COMPATIBLE: '#0f9d58', + INCOMPATIBLE: '#db4437', +}; +/** + * Color parameters for node encoding. + * @type {Object} + */ +export let MetanodeColors = { + /** + * Default fill and stroke to use when no other information is available. + */ + DEFAULT_FILL: '#d9d9d9', + DEFAULT_STROKE: '#a6a6a6', + SATURATION: 0.6, + LIGHTNESS: 0.85, + /** + * Neutral color to use when the node is expanded (used when coloring by + * compute time, memory and device). + */ + EXPANDED_COLOR: '#f0f0f0', + /** + * Standard hue values for node color palette. + */ + HUES: [220, 100, 180, 40, 20, 340, 260, 300, 140, 60], + STRUCTURE_PALETTE(id: number, lightened?: boolean) { + // The code below is a flexible way to computationally create a set + // of colors that go well together. + let hues = MetanodeColors.HUES; + let n = hues.length; + let hue = hues[id % n]; + let m = Math.sin((hue * Math.PI) / 360); + let sat = lightened ? 30 : 90 - 60 * m; + let light = lightened ? 95 : 80; + return d3.hsl(hue, 0.01 * sat, 0.01 * light).toString(); + }, + DEVICE_PALETTE(index: number): string { + return MetanodeColors.STRUCTURE_PALETTE(index); + }, + XLA_CLUSTER_PALETTE(index: number): string { + return MetanodeColors.STRUCTURE_PALETTE(index); + }, + UNKNOWN: '#eee', + GRADIENT_OUTLINE: '#888', +}; +/** + * Color parameters for op nodes. + */ +export let SeriesNodeColors = { + DEFAULT_FILL: 'white', + DEFAULT_STROKE: '#b2b2b2', +}; +/** + * Function that computes edge thickness in pixels. + */ +export interface EdgeThicknessFunction { + (edgeData: EdgeData, edgeClass: string): number; +} +/** + * Function that computes edge label strings. This function accepts a Metaedge, + * which could actually encapsulate several base edges. For instance, several + * base edges may merge into a single metaedge. + * + * To determine whether a metaedge represents several edges, check the length of + * its baseEdgeList property. + */ +export interface EdgeLabelFunction { + (metaedge: Metaedge, renderInfo: RenderGraphInfo): string; +} +/** + * Parameters that affect how the graph is rendered on the screen. + */ +const PARAMS = { + /** + * The minimum number of nodes for a graph to have in order for high in and + * out degree nodes to be extracted in auxiliary. The aim here is to prevent + * nodes from being extracted from small graphs. + */ + minNodeCountForExtraction: 15, + /** + * The minimum in or out degree a node must have in order to be possibly + * extracted. + */ + minDegreeForExtraction: 5, + /** + * Maximum number of control edges a node can have before they aren't + * displayed. + */ + maxControlDegree: 4, + /** + * Maximum in (for outbound bridge paths) or out (for inbound bridge paths) + * degree of a node allowed for a bridge path to be rendered to it from a + * subhierarchy of nodes. Having a max prevents having too many nodes emanate + * from a subhierarchy and crowding up. + */ + maxBridgePathDegree: 4, + /** + * Types patterns for predefined out-extract nodes, which are + * sink-like nodes that will be extracted from the main graph. + */ + outExtractTypes: ['NoOp'], + /** + * Types patterns for predefined in-extract nodes, which are + * source-like nodes that will be extracted from the main graph. + */ + inExtractTypes: [], + /** + * When removing edges from a high degree node, remove all of its edges if + * detachAllEdgesForHighDegree is true. Otherwise remove all in-edges if + * the node has high in-degree, or all out-edges if the node has high + * out-degree. + */ + detachAllEdgesForHighDegree: true, + /** + * After extracting high in/out degree nodes and predefined + * source-like/sink-like, extract isolated nodes to the side + * if this extractIsolatedNodesWithAnnotationsOnOneSide is true. + */ + extractIsolatedNodesWithAnnotationsOnOneSide: true, + /** + * Whether to add bridge nodes and edges to the core when building the + * subhierarchy of an expanded metanode. See buildSubhierarchy(). + */ + enableBridgegraph: true, + /** + * 2 colors, for the minimum and maximum value respectively, whenever we + * have a gradient scale. + */ + minMaxColors: ['#fff5f0', '#fb6a4a'], + /** + * Maximum number of annotations to be displayed on a node before an + * ellipsis is used. + */ + maxAnnotations: 5, +}; +/** + * The regular expression to use when parsing for the string that is + * used to label a function node in the graph. We strip away a prefix + * indicating that the node represents a function definition. We also + * remove an arbitrary hexadecimal suffix and the number following it + * if it is present. To be clear, we extract foo from + * __function_library__foo_deadb00f_42. + */ +const nodeDisplayNameRegex = new RegExp( + '^(?:' + tf_graph.FUNCTION_LIBRARY_NODE_PREFIX + ')?(\\w+)_[a-z0-9]{8}(?:_\\d+)?$', +); +/** + * Stores the rendering information, such as x and y coordinates, + * for each node in the graph. + */ +export class RenderGraphInfo { + hierarchy: Hierarchy; + private displayingStats: boolean; + private index: { + [nodeName: string]: RenderNodeInfo; + }; + renderedOpNames: string[]; + private toRenderEdgeObjs: { + v: string; + w: string; + id: number; + edge: tf_graph.Metaedge; + }[]; + deviceColorMap: d3.ScaleOrdinal; + memoryUsageScale: d3.ScaleLinear; + computeTimeScale: d3.ScaleLinear; + /** Scale for the thickness of edges when there is no shape information. */ + edgeWidthSizedBasedScale: d3.ScaleLinear | d3.ScalePower; + // Since the rendering information for each node is constructed lazily, + // upon node's expansion by the user, we keep a map between the node's name + // and whether the rendering information was already constructed for that + // node. + private hasSubhierarchy: { + [nodeName: string]: boolean; + }; + root: RenderGroupNodeInfo; + traceInputs: boolean; + edgeLabelFunction: EdgeLabelFunction; + // An optional function that computes the thickness of an edge given edge + // data. If not provided, defaults to encoding tensor size in thickness. + edgeWidthFunction: EdgeThicknessFunction; + constructor(hierarchy: Hierarchy, displayingStats: boolean) { + this.hierarchy = hierarchy; + this.displayingStats = displayingStats; + this.index = {}; + this.toRenderEdgeObjs = []; + this.renderedOpNames = []; + this.computeScales(); + // Maps node name to whether the rendering hierarchy was already + // constructed. + this.hasSubhierarchy = {}; + this.root = new RenderGroupNodeInfo(hierarchy.root, hierarchy.graphOptions); + this.index[hierarchy.root.name] = this.root; + this.renderedOpNames.push(hierarchy.root.name); + this.buildSubhierarchy(hierarchy.root.name); + this.root.expanded = true; + this.traceInputs = false; + } + computeScales() { + this.deviceColorMap = d3 + .scaleOrdinal() + .domain(this.hierarchy.devices) + .range(_.map(d3.range(this.hierarchy.devices.length), MetanodeColors.DEVICE_PALETTE)); + let topLevelGraph = this.hierarchy.root.metagraph; + // Find the maximum memory usage. Use 0 as the minimum. + let maxMemory = d3.max(topLevelGraph.nodes(), (nodeName, index) => { + let node = topLevelGraph.node(nodeName); + // Some ops don't have stats at all. + if (node.stats != null) { + return node.stats.totalBytes; + } + }); + this.memoryUsageScale = d3 + .scaleLinear() + .domain([0, maxMemory as unknown as number]) + .range(PARAMS.minMaxColors); + // Find the maximum compute time. Use 0 as the minimum. + let maxComputeTime = d3.max(topLevelGraph.nodes(), (nodeName, index) => { + let node = topLevelGraph.node(nodeName); + // Some ops don't have stats at all. + if (node.stats != null) { + return node.stats.getTotalMicros(); + } + }); + this.computeTimeScale = d3 + .scaleLinear() + .domain([0, maxComputeTime as unknown as number]) + .range(PARAMS.minMaxColors); + this.edgeWidthSizedBasedScale = tf_graph_common.EDGE_WIDTH_SIZE_BASED_SCALE; + } + /** + * Get a previously created RenderNodeInfo by its node name. + */ + getRenderNodeByName(nodeName: string): RenderNodeInfo { + return this.index[nodeName]; + } + /** + * Get the underlying node in the hierarchical graph by its name. + */ + getNodeByName(nodeName: string): Node { + return this.hierarchy.node(nodeName); + } + private colorHistogram( + histogram: { + [name: string]: number; + }, + colors: d3.ScaleOrdinal, + ): Array<{ + color: string; + proportion: number; + }> { + if (Object.keys(histogram).length > 0) { + // Compute the total # of items. + const numItems = _.sum(Object.keys(histogram).map((key) => histogram[key])); + return Object.keys(histogram).map((key) => ({ + color: colors(key), + // Normalize to a proportion of total # of items. + proportion: histogram[key] / numItems, + })); + } + return null!; + } + /** + * Get a previously created RenderNodeInfo for the specified node name, + * or create one if it hasn't been created yet. + */ + getOrCreateRenderNodeByName(nodeName: string): RenderNodeInfo { + // Polymer may invoke this with null. + if (!nodeName) { + return null!; + } + if (nodeName in this.index) { + return this.index[nodeName]; + } + let node = this.hierarchy.node(nodeName); + // Exit early if the node does not exist in the hierarchy. This can happen + // when a graph is reloaded while the infocard points to a node not visible + // at the top-level. + if (!node) { + return null!; + } + let renderInfo = node.isGroupNode + ? new RenderGroupNodeInfo(node, this.hierarchy.graphOptions) + : new RenderNodeInfo(node); + this.index[nodeName] = renderInfo; + this.renderedOpNames.push(nodeName); + if (node.stats) { + renderInfo.memoryColor = this.memoryUsageScale(node.stats.totalBytes); + renderInfo.computeTimeColor = this.computeTimeScale(node.stats.getTotalMicros()); + } + // We only fade nodes when we're displaying stats. + renderInfo.isFadedOut = this.displayingStats && !tf_graph_util.hasDisplayableNodeStats(node.stats); + var deviceHistogram: {} | null = null; + var xlaClusterHistogram: {} | null = null; + var opCompatibility: number | null = null; + if (node.isGroupNode) { + deviceHistogram = (node).deviceHistogram; + xlaClusterHistogram = (node).xlaClusterHistogram; + let compat = (node).compatibilityHistogram.compatible; + let incompat = (node).compatibilityHistogram.incompatible; + if (compat != 0 || incompat != 0) { + opCompatibility = compat / (compat + incompat); + } + } else { + let device = (renderInfo.node).device; + if (device) { + deviceHistogram = { [device]: 1 }; + } + let xlaCluster = (renderInfo.node).xlaCluster; + if (xlaCluster) { + xlaClusterHistogram = { [xlaCluster]: 1 }; + } + if (renderInfo.node.type === NodeType.OP) { + opCompatibility = (renderInfo.node).compatible ? 1 : 0; + } + } + if (deviceHistogram) { + renderInfo.deviceColors = this.colorHistogram(deviceHistogram, this.deviceColorMap); + } + if (opCompatibility != null) { + renderInfo.compatibilityColors = [ + { + color: OpNodeColors.COMPATIBLE, + proportion: opCompatibility, + }, + { + color: OpNodeColors.INCOMPATIBLE, + proportion: 1 - opCompatibility, + }, + ]; + } + return this.index[nodeName]; + } + /** + * Return the nearest ancestor node, including itself, that is visible + * in the visualization. This method is used so that we can select + * (highlight) a node that isn't drawn yet, by selecting (highlighting) + * its nearest ancestor that has been drawn. + */ + getNearestVisibleAncestor(name: string): string { + let path = getHierarchicalPath(name); + let i = 0; + let renderNode: RenderNodeInfo | null = null; + // Fallthrough. If everything was expanded return the node. + let nodeName = name; + for (; i < path.length; i++) { + nodeName = path[i]; + renderNode = this.getRenderNodeByName(nodeName); + // Op nodes have expanded set to false by default. + if (renderNode && !renderNode.expanded) { + break; + } + } + if (!renderNode) { + return ''; + } + // Check case where highlighted node is an embedded node whose parent node + // is also its hierarchical parent. In this case, we want to return the + // embedded node name, as it is also displayed if its parent has been + // displayed. + if (i == path.length - 2) { + let nextName = path[i + 1]; + if (renderNode?.inAnnotations.nodeNames[nextName]) { + return nextName; + } + if (renderNode?.outAnnotations.nodeNames[nextName]) { + return nextName; + } + } + return nodeName; + } + // TODO: Delete this an any code it touches (all deprecated). + setDepth(depth: number): void { + setGroupNodeDepth(this.root, +depth); + } + /** + * Returns true if the renderNode is an isolated node within its parent node. + */ + isNodeAuxiliary(renderNode: RenderNodeInfo): boolean { + let parentNode = this.getRenderNodeByName(renderNode.node.parentNode.name); + let found = _.find(parentNode.isolatedInExtract, (node) => { + return node.node.name === renderNode.node.name; + }); + if (found) { + return true; + } + found = _.find(parentNode.isolatedOutExtract, (node) => { + return node.node.name === renderNode.node.name; + }); + return !!found; + } + /** + * Returns a list of ops that have been rendered so far for this graph. More + * ops may later be rendered if the user expands nodes for instance. The list + * returned here can only stay the same size or grow on successive calls. + */ + getNamesOfRenderedOps(): string[] { + return this.renderedOpNames; + } + /** + * Clones an op node and adds it to a metagraph. Does nothing if an op node + * with the same new name has already been created within the metagraph. This + * method is used when duplicating a library function to be injected within a + * metanode representing a function call. + * @param parentMetanode The parent metanode on which to add the new node. + * @param node The op node to clone. + * @param newPrefix The prefix string to use in lieu of the one that merely + * indicates that the metanode represents a function defined in the + * library. This prefix should reflect graph hierarchy. + * @return The newly created op node (the clone of the original). + */ + private cloneAndAddFunctionOpNode( + parentMetanode: Metanode, + libraryFunctionNodeName: string, + node: OpNode, + newPrefix: string, + ): OpNode { + const newName = node.name.replace(libraryFunctionNodeName, newPrefix); + let newOpNode = parentMetanode.metagraph.node(newName) as any; + if (newOpNode) { + // This node had already been created and added to the graph. + return newOpNode; + } + // Create a new op node. + newOpNode = new OpNodeImpl({ + name: newName, + input: [], + device: node.device, + op: node.op, + input_data: _.cloneDeep(node.inputData), + output_data: _.cloneDeep(node.outputData), + stack_info: _.cloneDeep(node.stackData), + suggestions: _.cloneDeep(node.suggestions), + isLeaf: false, + attr: _.cloneDeep(node.attr), + node_type: NodeOpType.DEFAULT, + matched_node_link: _.cloneDeep(node.matchedNodeLink), + }); + // Update various properties. + newOpNode.cardinality = node.cardinality; + newOpNode.include = node.include; + newOpNode.outputShapes = _.cloneDeep(node.outputShapes); + newOpNode.xlaCluster = node.xlaCluster; + newOpNode.functionInputIndex = node.functionInputIndex; + newOpNode.functionOutputIndex = node.functionOutputIndex; + // Update the inputs of the new node to reflect the new path. + newOpNode.inputs = node.inputs.map((normalizedInput) => { + const newNormalizedInput = _.clone(normalizedInput); + newNormalizedInput.name = normalizedInput.name.replace(libraryFunctionNodeName, newPrefix); + return newNormalizedInput; + }); + // Add the new op node to the hierarchy and metagraph. Also add it to its + // parent metanode. + newOpNode.parentNode = parentMetanode; + parentMetanode.metagraph.setNode(newOpNode.name, newOpNode); + this.hierarchy.setNode(newOpNode.name, newOpNode); + // Update embeddings. + const updateEmbeddingOpNode = (embeddingNode) => { + return this.cloneAndAddFunctionOpNode(parentMetanode, libraryFunctionNodeName, embeddingNode, newPrefix); + }; + newOpNode.inEmbeddings = node.inEmbeddings.map(updateEmbeddingOpNode); + newOpNode.outEmbeddings = node.outEmbeddings.map(updateEmbeddingOpNode); + return newOpNode; + } + /** + * Clones a Metanode that represents a function defined in the graph library. + * We dynamically inject a clone of a function into a meta graph when the user + * expands a function call. We cannot do this at the beginning because the + * functions may recursively call themselves or other functions. + * @param metagraph The metagraph we are currently rendering the sub-hierarchy + * for. + * @param opNodeToReplace The op node in the graph to replace with a new + * (expandable) metanode that visualizes the innards of a function. + * @param libraryMetanode The metanode for a library function to clone. + * @param oldPrefix The old prefix to replace (that just reflects how this + * node is for a library function). + * @param newPrefix The prefix string to use in lieu of the one that merely + * indicates that the metanode represents a function defined in the + * library. This prefix should reflect graph hierarchy. + */ + private cloneFunctionLibraryMetanode( + metagraph: graphlib.Graph, + opNodeToReplace: OpNode, + libraryMetanode: Metanode, + oldPrefix: string, + newPrefix: string, + ): Metanode { + // Make a mapping between function output index and the new node for the + // output. + const functionOutputIndexToNode = {}; + const newMetanode = this.cloneFunctionLibraryMetanodeHelper( + metagraph, + opNodeToReplace, + libraryMetanode, + oldPrefix, + newPrefix, + functionOutputIndexToNode, + ); + if (!_.isEmpty(functionOutputIndexToNode)) { + // After we have cloned the edges within the metanode, we still must add + // edges that emanate out of output ops within the function. + this.patchEdgesFromFunctionOutputs(opNodeToReplace, functionOutputIndexToNode); + } + return newMetanode; + } + /** + * A helper subroutine that performs the bulk of the logic for + * `cloneFunctionLibraryMetanode`. + * @param metagraph The metagraph we are currently rendering the sub-hierarchy + * for. + * @param opNodeToReplace The op node in the graph to replace with a new + * (expandable) metanode that visualizes the innards of a function. + * @param libraryMetanode The metanode for a library function to clone. + * @param oldPrefix The old prefix to replace (that just reflects how this + * node is for a library function). + * @param newPrefix The prefix string to use in lieu of the one that merely + * indicates that the metanode represents a function defined in the + * library. This prefix should reflect graph hierarchy. + * @param functionOutputIndexToNode A mapping between function output index + * and the corresponding output node. Used to connect outputs with + * destinations outside of the function metanode. + */ + private cloneFunctionLibraryMetanodeHelper( + metagraph: graphlib.Graph, + opNodeToReplace: OpNode, + libraryMetanode: Metanode, + oldPrefix: string, + newPrefix: string, + functionOutputIndexToNode: { [key: string]: Node }, + ): Metanode { + const newMetanode = tf_graph.createMetanode(libraryMetanode.name.replace(oldPrefix, newPrefix)); + // Copy over various properties. + newMetanode.depth = libraryMetanode.depth; + newMetanode.cardinality = libraryMetanode.cardinality; + newMetanode.templateId = libraryMetanode.templateId; + newMetanode.opHistogram = _.clone(libraryMetanode.opHistogram); + newMetanode.deviceHistogram = _.clone(libraryMetanode.deviceHistogram); + newMetanode.xlaClusterHistogram = _.clone(libraryMetanode.xlaClusterHistogram); + newMetanode.hasNonControlEdges = libraryMetanode.hasNonControlEdges; + newMetanode.include = libraryMetanode.include; + newMetanode.nodeAttributes = _.clone(libraryMetanode.nodeAttributes); + newMetanode.associatedFunction = libraryMetanode.associatedFunction; + // Recursively duplicate the children nodes. + _.each(libraryMetanode.metagraph.nodes(), (nodeName) => { + const node = libraryMetanode.metagraph.node(nodeName); + switch (node.type) { + case NodeType.META: + // Recursively duplicate the metanode. + const newNode = this.cloneFunctionLibraryMetanodeHelper( + metagraph, + opNodeToReplace, + node as any, + oldPrefix, + newPrefix, + functionOutputIndexToNode, + ); + // Add the new node to the graph. + newNode.parentNode = newMetanode; + newMetanode.metagraph.setNode(newNode.name, newNode); + this.hierarchy.setNode(newNode.name, newNode); + break; + case NodeType.API_LIST: + // Recursively duplicate the metanode. + const newApiNodeList = this.cloneFunctionLibraryMetanodeHelper( + metagraph, + opNodeToReplace, + node as any, + oldPrefix, + newPrefix, + functionOutputIndexToNode, + ); + // Add the new node to the graph. + newApiNodeList.parentNode = newMetanode; + newMetanode.metagraph.setNode(newApiNodeList.name, newApiNodeList); + this.hierarchy.setNode(newApiNodeList.name, newApiNodeList); + break; + case NodeType.OP: + // Duplicate the op node. + const newOpNode = this.cloneAndAddFunctionOpNode(newMetanode, oldPrefix, node as any, newPrefix); + if (_.isNumber(newOpNode.functionInputIndex)) { + // This node represents an input_arg of the library function. Give + // it edges so that its bridge edges are created correctly. + this.patchEdgesIntoFunctionInputs(opNodeToReplace, newOpNode); + } + if (_.isNumber(newOpNode.functionOutputIndex)) { + functionOutputIndexToNode[newOpNode.functionOutputIndex] = newOpNode; + } + break; + default: + // This logic should never run because the meta graph should only + // contain meta and op nodes. + console.warn(node.name + ' is oddly neither a metanode nor an opnode.'); + } + }); + // Clone the edges within the function library metanode. + this.cloneLibraryMetanodeEdges(libraryMetanode, newMetanode, oldPrefix, newPrefix); + return newMetanode; + } + /** + * Clones the edges within `libraryMetanode` and adds them to `newMetanode`. + * The names of edge sources and destinations have their prefixes replaced + * with new prefixes that reflect their hierarchical positions in the graph + * instead of within the function library template. This is a subroutine for + * dynamically injecting a function metanode into the graph. + */ + private cloneLibraryMetanodeEdges( + libraryMetanode: Metanode, + newMetanode: Metanode, + oldPrefix: string, + newPrefix: string, + ) { + _.each(libraryMetanode.metagraph.edges(), (edgeObject) => { + const edge = libraryMetanode.metagraph.edge(edgeObject); + const newV = edge.v.replace(oldPrefix, newPrefix); + const newW = edge.w.replace(oldPrefix, newPrefix); + const newMetaEdge = new MetaedgeImpl(newV, newW); + // Duplicate various properties. + newMetaEdge.inbound = edge.inbound; + newMetaEdge.numRegularEdges = edge.numRegularEdges; + newMetaEdge.numRefEdges = edge.numRefEdges; + newMetaEdge.totalSize = edge.totalSize; + if (edge.baseEdgeList) { + newMetaEdge.baseEdgeList = edge.baseEdgeList.map((baseEdge) => { + const newBaseEdge = _.clone(baseEdge); + newBaseEdge.v = baseEdge.v.replace(oldPrefix, newPrefix); + newBaseEdge.w = baseEdge.w.replace(oldPrefix, newPrefix); + return newBaseEdge; + }); + } + // Set the direction of the edge based on whether it is inbound. The edge + // is inbound if its destination is within the metagraph. + if (newMetanode.metagraph.node(newW)) { + newMetanode.metagraph.setEdge(newV, newW, newMetaEdge); + } else { + newMetanode.metagraph.setEdge(newW, newV, newMetaEdge); + } + }); + } + /** + * When a metanode representing a function is cloned and placed into the + * graph, we must create edges between inputs into the function call and the + * input ops within the function. This function performs that patching. + */ + private patchEdgesIntoFunctionInputs(opNodeToReplace: OpNode, newOpNode: OpNode) { + // If the last few raw inputs are the same node, previous graph logic + // collapses them into a single normalized input. + let inputIndex = Math.min(newOpNode.functionInputIndex, opNodeToReplace.inputs.length - 1); + let newInput = _.clone(opNodeToReplace.inputs[inputIndex]); + // Clone the normalized input object. + newOpNode.inputs.push(newInput); + // Update values in the corresponding edge in the high-level + // metagraph. + const originalMetaEdges = this.hierarchy.getPredecessors(opNodeToReplace.name); + // Find the metaedge that the input index corresponds to. + // A metaedge may correspond to several edges. For instance, + // an edge may enter a series node. + let originalMetaEdge: Metaedge; + let regularEdgeCount = 0; + _.each(originalMetaEdges.regular, (metaEdge) => { + regularEdgeCount += metaEdge.numRegularEdges; + if (regularEdgeCount > inputIndex) { + originalMetaEdge = metaEdge; + // Terminate the loop. + return false; + } + }); + // Also change any base edges that point into the original node to + // point to the input arg within the function. These are used to + // make bridge edges. + _.each(originalMetaEdge!.baseEdgeList, (edge) => { + if (edge.w === opNodeToReplace.name) { + edge.w = newOpNode.name; + } + if (edge.v === opNodeToReplace.name) { + edge.v = newOpNode.name; + } + }); + } + /** + * When a metanode representing a function is cloned and placed into the + * graph, we must create edges between output ops within the new function + * metanode to its successors. This function does that after scanning the + * successors of the function call. + */ + private patchEdgesFromFunctionOutputs( + opNodeToReplace: OpNode, + functionOutputIndexToDestinationNode: { + [key: string]: Node; + }, + ) { + // Connect the outputs of the function to other ops. + const originalMetaEdges = this.hierarchy.getSuccessors(opNodeToReplace.name); + _.each(originalMetaEdges.regular, (metaedge) => { + _.each(metaedge.baseEdgeList, (baseEdge) => { + // Destination nodes within regular base edges are op nodes. + const destinationNode = this.hierarchy.node(baseEdge.w!) as OpNode; + _.each(destinationNode.inputs, (normalizedInput) => { + // If an output of the function is an input into the op, map it back + // to the output within the function so bridge edges are computed. + if (normalizedInput.name === opNodeToReplace.name) { + // Map the output tensor index (which in this case is for sure + // numeric because it is an output of a metanode) to the correct + // function output. + const outputNode = functionOutputIndexToDestinationNode[normalizedInput.outputTensorKey]; + normalizedInput.name = outputNode.name; + normalizedInput.outputTensorKey = baseEdge.outputTensorKey; + } + }); + }); + // Modify the list of base edges to point from the output so that bridge + // edges are correct. + _.each(metaedge.baseEdgeList, (baseEdge) => { + baseEdge.v = functionOutputIndexToDestinationNode[baseEdge.outputTensorKey].name; + baseEdge.outputTensorKey = '0'; + }); + }); + } + buildSubhierarchy(nodeName: string, subGraph: tf_graph.SlimGraph | undefined = undefined): void { + // Terminate if the rendering hierarchy was already constructed + // for this node. + if (nodeName in this.hasSubhierarchy) { + return; + } + // Record that we constructed the rendering hierarchy for this node, so we + // don't construct it another time. + this.hasSubhierarchy[nodeName] = true; + let renderNodeInfo = this.index[nodeName]; + // If it is not a meta node or a series node, don't do anything. + if ( + renderNodeInfo.node.type !== NodeType.META && + renderNodeInfo.node.type !== NodeType.API_LIST && + renderNodeInfo.node.type !== NodeType.SERIES + ) { + return; + } + // At this point we know the rendering information is about a group node. + let renderGroupNodeInfo = renderNodeInfo; + let metagraph = renderGroupNodeInfo.node.metagraph; + let coreGraph = renderGroupNodeInfo.coreGraph; + // Create render nodes to represent each child from the metagraph. Although + // these will initially be added to the coreGraph, they may later be + // extracted. Also, due to extraction, the coreGraph may contain disjoint + // groups between which there is no visible path (other than annotations). + _.each(metagraph.nodes(), (childName, index: number) => { + let childRenderInfo = this.getOrCreateRenderNodeByName(childName); + if (!childRenderInfo) { + return; + } + let childNode = childRenderInfo.node; + coreGraph.setNode(childName, childRenderInfo); + if (index >= 1 && subGraph && Object.keys(subGraph.metaNodes).length > 0) { + coreGraph.setEdge(metagraph.nodes()[index - 1], childName, {}); + } + if (!childNode.isGroupNode) { + _.each((childNode).inEmbeddings, (embedding) => { + let renderMetaedgeInfo = new RenderMetaedgeInfo(null!); + let renderNodeInfo = new RenderNodeInfo(embedding); + addInAnnotation(childRenderInfo, embedding, renderNodeInfo, renderMetaedgeInfo, AnnotationType.CONSTANT); + this.index[embedding.name] = renderNodeInfo; + }); + _.each((childNode).outEmbeddings, (embedding) => { + let renderMetaedgeInfo = new RenderMetaedgeInfo(null!); + let renderNodeInfo = new RenderNodeInfo(embedding); + addOutAnnotation(childRenderInfo, embedding, renderNodeInfo, renderMetaedgeInfo, AnnotationType.SUMMARY); + this.index[embedding.name] = renderNodeInfo; + }); + } + }); + // Add render metaedge info for edges in the metagraph. + _.each(metagraph.edges(), (edgeObj) => { + if (edgeObj.v.includes(tf_graph.NAMESPACE_DELIM) || edgeObj.w.includes(tf_graph.NAMESPACE_DELIM)) { + let inbound = edgeObj.w.includes(tf_graph.NAMESPACE_DELIM); + if (inbound) { + let pathNames = edgeObj.w.split(tf_graph.NAMESPACE_DELIM); + this.toRenderEdgeObjs.push({ + v: edgeObj.v, + w: pathNames[pathNames.length - 1], + id: edgeObj.name, + edge: metagraph.edge(edgeObj), + }); + } else { + let pathNames = edgeObj.v.split(tf_graph.NAMESPACE_DELIM); + this.toRenderEdgeObjs.push({ + v: pathNames[pathNames.length - 1], + w: edgeObj.w, + id: edgeObj.name, + edge: metagraph.edge(edgeObj), + }); + } + return; + } + let metaedge = metagraph.edge(edgeObj); + let renderMetaedgeInfo = new RenderMetaedgeInfo(metaedge as any); + renderMetaedgeInfo.isFadedOut = this.index[edgeObj.v].isFadedOut || this.index[edgeObj.w].isFadedOut; + coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo, edgeObj.name); + }); + // Look up the parent node's render information and short circuit if none. + let parentNode = renderGroupNodeInfo.node.parentNode; + if (!parentNode) { + return; + } + let parentNodeInfo = this.index[parentNode.name]; + // Utility function for computing the name of a bridge node. + let getBridgeNodeName = (inbound, ...rest) => rest.concat([inbound ? 'IN' : 'OUT']).join('~~'); + if (subGraph) { + const subNodes = Object.keys(subGraph.metaNodes).concat(Object.keys(subGraph.nodes)); + renderNodeInfo.node.cardinality += subNodes.length; + const parentMetagraph = (parentNode as GroupNode).metagraph; + const parentBridgegraph = (parentNode as GroupNode).bridgegraph; + _.each(subNodes, (subName) => { + this.toRenderEdgeObjs + .filter((e) => e.v === subName || e.w === subName) + .forEach((item) => { + const edgeObj = item.edge; + const inbound = item.w === subName; + const edgeIO = inbound + ? [edgeObj.v?.split(tf_graph.NAMESPACE_DELIM)[0], nodeName] + : [nodeName, edgeObj.w?.split(tf_graph.NAMESPACE_DELIM)[0]]; + let bridgeMetaedge = parentMetagraph.edge(...edgeIO, item.id); + if (!bridgeMetaedge) { + bridgeMetaedge = parentBridgegraph.edge(...edgeIO, item.id); + if (!bridgeMetaedge) { + return; + } + } + _.each(edgeObj.baseEdgeList, (baseEdge) => { + let name = inbound ? baseEdge.v : baseEdge.w; + if (baseEdge.attr) { + baseEdge.attr['_path'] = name; + } else { + baseEdge.attr = { _path: name }; + } + if (inbound) { + baseEdge.v = name?.split(tf_graph.NAMESPACE_DELIM).pop(); + baseEdge.w = subName; + } else { + baseEdge.v = subName; + baseEdge.w = name?.split(tf_graph.NAMESPACE_DELIM).pop(); + } + bridgeMetaedge.addBaseEdge(baseEdge, this.hierarchy, true); + }); + }); + }); + } + // Remove rendered edges to save memory. + this.toRenderEdgeObjs = this.toRenderEdgeObjs.filter((e) => !this.index[e.v] || !this.index[e.w]); + // Build out the bridgegraph. + let bridgegraph = this.hierarchy.getBridgegraph(nodeName); + // Look for popular nodes so we can make annotations instead of paths. + let otherCounts = { + // Counts of edges coming INTO other nodes by name (outgoing from self). + in: < + { + [nodeName: string]: number; + } + >{}, + // Counts of edges going OUT from other nodes by name (coming into self). + out: < + { + [nodeName: string]: number; + } + >{}, + // Counts of all control edges involving other nodes by name. + control: < + { + [nodeName: string]: number; + } + >{}, + }; + _.each(bridgegraph.edges(), (e) => { + // An edge is inbound if its destination node is in the metagraph. + let inbound = !!metagraph.node(e.w); + let otherName = inbound ? e.v : e.w; + let metaedge = bridgegraph.edge(e); + if (!metaedge.numRegularEdges) { + otherCounts.control[otherName] = (otherCounts.control[otherName] || 0) + 1; + } else if (inbound) { + otherCounts.out[otherName] = (otherCounts.out[otherName] || 0) + 1; + } else { + otherCounts.in[otherName] = (otherCounts.in[otherName] || 0) + 1; + } + }); + // Add annotations and edges for bridgegraph relationships. + let hierarchyNodeMap = this.hierarchy.getNodeMap(); + // _.each(bridgegraph.edges(), (bridgeEdgeObj) => { + // let bridgeMetaedge = bridgegraph.edge(bridgeEdgeObj); + // // Determine whether this bridge edge is incoming by checking the + // // metagraph for a node that matches the destination end. + // let inbound = !!metagraph.node(bridgeEdgeObj.w); + // // Based on the direction of the edge, one endpoint will be an immediate + // // child of this renderNodeInfo, and the other endpoint will be a sibling + // // of the parent (or an ancestor further up). + // let [childName, otherName] = inbound + // ? [bridgeEdgeObj.w, bridgeEdgeObj.v] + // : [bridgeEdgeObj.v, bridgeEdgeObj.w]; + // let childRenderInfo = this.index[childName]; + // let otherRenderInfo = this.index[otherName]; + // let otherNode = otherRenderInfo + // ? otherRenderInfo.node + // : hierarchyNodeMap[otherName]; + // // Determine whether this edge is a control edge between nodes where + // // either node is high-degree with respect to control edges. This will + // // be a signal to show it as an annotation instead of a bridge edge. + // let isHighDegreeControlEdge = + // !bridgeMetaedge.numRegularEdges && + // otherCounts.control[otherName] > PARAMS.maxControlDegree; + // let [, childAnnotations] = inbound + // ? [renderNodeInfo.inAnnotations, childRenderInfo.inAnnotations] + // : [renderNodeInfo.outAnnotations, childRenderInfo.outAnnotations]; + // // Don't render a bridge path if the other node has in or out degree above + // // a threshold, lest bridge paths emanating out of a metagraph crowd up, + // // as was the case for the Fatcat LSTM lstm_1 > lstm_1 metagraph. + // let otherDegreeCount = (inbound ? otherCounts.out : otherCounts.in)[ + // otherName + // ]; + // let isOtherHighDegree = otherDegreeCount > PARAMS.maxBridgePathDegree; + // // The adjoining render metaedge info from the parent's coreGraph, if any. + // // It will either be a Metaedge involving this node directly, if it + // // previously came from a metagraph, or it'll be a Metaedge involving + // // a previously created bridge node standing in for the other node. + // let adjoiningMetaedge: RenderMetaedgeInfo | null = null; + // // We can only hope to render a bridge path if: + // // - bridgegraph paths are enabled, + // // - the other node is not too high-degree, + // // - the child is in the core (not extracted for being high-degree), and + // // - there's a path (in the traversal sense) between child and other. + // let canDrawBridgePath = false; + // if ( + // PARAMS.enableBridgegraph && + // !isOtherHighDegree && + // !isHighDegreeControlEdge && + // childRenderInfo.isInCore() + // ) { + // // Utility function for finding an adjoining metaedge. + // let findAdjoiningMetaedge = (targetName) => { + // let adjoiningEdgeObj = inbound + // ? { v: targetName, w: nodeName, name: bridgeEdgeObj.name } + // : { v: nodeName, w: targetName, name: bridgeEdgeObj.name }; + // return ( + // parentNodeInfo.coreGraph.edge(adjoiningEdgeObj) + // ); + // }; + // adjoiningMetaedge = findAdjoiningMetaedge(otherName); + // if (!adjoiningMetaedge) { + // adjoiningMetaedge = findAdjoiningMetaedge( + // getBridgeNodeName(inbound, otherName, parentNode.name) + // ); + // } + // canDrawBridgePath = !!adjoiningMetaedge; + // } + // // Although dataflow edges are acyclic, control dependency edges may + // // actually point 'backwards' in the graph. If this bridgeMetaedge is + // // a control dependency, we need to determine whether it's backwards + // // pointing so that we render it appropriately. + // // + // // For instance, say we're rendering a graph with nodes named A/B and Z/Y, + // // and we're currently rendering the bridgegraph for A. Further, let's say + // // that there was an original BaseEdge from A/B->Z/Y and a CONTROL EDGE + // // from Z/Y=>A/B. + // // + // // +----------------+ + // // | A | + // // | +-----+ | +------+ + // // | | B |>----->|>------->| Z | + // // | | | | | | + // // | | | * | | | + // // | | |<=====<|<=======<| | + // // | +-----+ | +------+ + // // +----------------+ + // // + // // When we render the subhierarchy for Metanode A, we'll come across a + // // control-only Metaedge in the bridgegraph from Z=>A/B (*). The question + // // is whether this edge is backwards. + // // + // // To answer that question, we follow the chain of adjoining metaedges + // // until we reach the topmost one. In this case, that's the control-only + // // Metaedge Z=>A in the ROOT's metagraph. We determine that this edge + // // is backwards by looking at the topological ordering of ROOT's metagraph + // // (which ignores control edges) and seeing that Z comes AFTER A. + // // + // // The property of being backwards is independent of whether the edge + // // is inbound or outbound. In the preceding example, if we were building + // // the subhierarchy for Z, we'd find bridge edge Z/Y=>A, walk to its + // // topmost adjoining metaedge Z=>A and discover that it's backwards. + // let backwards = false; + // if (adjoiningMetaedge && !bridgeMetaedge.numRegularEdges) { + // // Find the top-most adjoining render metaedge information, and the + // // GroupNode whose metagraph must contain the associated metaedge. + // let topAdjoiningMetaedge = adjoiningMetaedge; + // let topGroupNode = parentNodeInfo.node; + // while (topAdjoiningMetaedge.adjoiningMetaedge) { + // topAdjoiningMetaedge = topAdjoiningMetaedge.adjoiningMetaedge; + // topGroupNode = topGroupNode.parentNode; + // } + // // Check against the topological ordering for the top node. The current + // // bridge metaedge we're evaluating is backwards if its source comes + // // after its destination. + // let ordering = this.hierarchy.getTopologicalOrdering(topGroupNode.name); + // let e = topAdjoiningMetaedge.metaedge; + // backwards = ordering[e.v!] > ordering[e.w!]; + // } + // // Render backwards control edges as annotations. + // canDrawBridgePath = canDrawBridgePath && !backwards; + // // If we can't make a bridge path for any reason, then we add an + // // annotation instead. + // if (!canDrawBridgePath) { + // childAnnotations.push( + // new Annotation( + // otherNode, + // otherRenderInfo, + // new RenderMetaedgeInfo(bridgeMetaedge as any), + // AnnotationType.SHORTCUT, + // inbound + // ) + // ); + // return; + // } + // // At this point, all conditions have been met for drawing a bridge path. + // // Find or create the IN/OUT node representing otherNode. + // let bridgeContainerName = getBridgeNodeName(inbound, nodeName); + // let bridgeNodeName = getBridgeNodeName(inbound, otherName, nodeName); + // let bridgeNodeRenderInfo = coreGraph.node(bridgeNodeName); + // if (!bridgeNodeRenderInfo) { + // // Find or create the directional container for the bridge node. + // let bridgeContainerInfo = coreGraph.node(bridgeContainerName); + // if (!bridgeContainerInfo) { + // let bridgeContainerNode: BridgeNode = { + // // Important node properties. + // name: bridgeContainerName, + // type: NodeType.BRIDGE, + // // Unused node properties. + // isGroupNode: false, + // cardinality: 0, + // parentNode: null!, + // stats: null!, + // include: InclusionType.UNSPECIFIED, + // // BridgeNode properties. + // inbound: inbound, + // nodeAttributes: {}, + // }; + // bridgeContainerInfo = new RenderNodeInfo(bridgeContainerNode); + // this.index[bridgeContainerName] = bridgeContainerInfo as any; + // coreGraph.setNode(bridgeContainerName, bridgeContainerInfo); + // } + // let bridgeNode: BridgeNode = { + // // Important node properties. + // name: bridgeNodeName, + // type: NodeType.BRIDGE, + // // Unimportant node properties. + // isGroupNode: false, + // cardinality: 1, + // parentNode: null!, + // stats: null!, + // include: InclusionType.UNSPECIFIED, + // // BridgeNode properties. + // inbound: inbound, + // nodeAttributes: {}, + // }; + // bridgeNodeRenderInfo = new RenderNodeInfo(bridgeNode); + // this.index[bridgeNodeName] = bridgeNodeRenderInfo as any; + // coreGraph.setNode(bridgeNodeName, bridgeNodeRenderInfo); + // // Set bridgeNode to be a graphlib child of the container node. + // coreGraph.setParent(bridgeNodeName, bridgeContainerName); + // bridgeContainerInfo.node.cardinality++; + // } + // // Create and add a bridge render metaedge. + // let bridgeRenderMetaedge = new RenderMetaedgeInfo(bridgeMetaedge as any); + // bridgeRenderMetaedge.adjoiningMetaedge = adjoiningMetaedge!; + // inbound + // ? coreGraph.setEdge(bridgeNodeName, childName, bridgeRenderMetaedge, bridgeEdgeObj.name) + // : coreGraph.setEdge(childName, bridgeNodeName, bridgeRenderMetaedge, bridgeEdgeObj.name); + // }); // End _.each(bridgegraph.edges). + // For each bridge container (IN and/or OUT), add structural edges between + // terminal nodes and that container. A terminal node is one which has no + // non-bridge edges in the direction of the container. + // + // For example, consider a Metanode A which contains two child nodes A/B + // and A/C. Let's say it has one edge in the metagraph from A/B->A/C, and + // one edge in the bridgegraph from Z->A/C. + // + // At this point, we've added a container bridge node IN to house all + // incoming bridge nodes. We've also added a bridge node Z' (with parent IN) + // to A, and a bridge edge from Z'->C. + // + // +----------------------+ + // | A +---+ | + // | +------>| C | | + // | | +---+ | + // | | ^ | + // | | | | + // | | +----|----+ | + // | | | IN | | | + // | +---+ | +---+ | | + // | | B | | | Z'| | | + // | +---+ | +---+ | | + // | +---------+ | + // +----------------------+ + // + // With no other help, dagre would lay out B and Z' on the same level, + // because both of them have no incoming edges. In other words, B is a + // terminal node in the INCOMING direction. + // + // But we want to force dagre to lay out Z' (and everything in IN) lower + // than all non-bridge nodes, so that there's enough room for the bridge + // edges after they've been adjusted to meet up with paths coming in from + // outside. + // + // To force Z' (and all other bridge nodes) to be lowest in the graph, we + // identify terminal nodes like B and give them structural edges to + // a new structural bridge node S which we add to IN. + // + // +----------------------+ + // | A +---+ | + // | +--->| C | | + // | | +---+ | + // | +---+ ^ | + // | | B | | | + // | +---+ | | + // | ^ | | + // | | | | + // | +----|------|----+ | + // | |IN | | | | + // | | +---+ +---+ | | + // | | | S | | Z'| | | + // | | +---+ +---+ | | + // | +----------------+ | + // +----------------------+ + // + // This ensures that dagre will lay out the bridge containers strictly at + // the ends of the graph. The structural edges will never be seen in the + // visualization except as a debugging aid. + _.each([true, false], (inbound) => { + let bridgeContainerName = getBridgeNodeName(inbound, nodeName); + let bridgeContainerInfo = coreGraph.node(bridgeContainerName); + if (!bridgeContainerInfo) { + return; + } + _.each(coreGraph.nodes(), (childName) => { + // Short-circuit if this child is a bridge node or it's not a terminal + // node in the direction we're interested in. + let childNodeInfo = coreGraph.node(childName); + if (childNodeInfo.node.type === NodeType.BRIDGE) { + return; + } + let isTerminal = inbound + ? !coreGraph.predecessors(childName)?.length! + : !coreGraph.successors(childName)?.length!; + if (!isTerminal) { + return; + } + // Find or create a bridge node in the container for all structural + // metaedges. It would have been nice to skip this step and simply + // set a metaedge between the terminal node and the container node, but + // in that case, something about the graph upsets dagre.layout()'s + // longestPath algorithm (was getting errors due to an undefined). + let structuralNodeName = getBridgeNodeName(inbound, nodeName, 'STRUCTURAL_TARGET'); + let structuralRenderInfo = coreGraph.node(structuralNodeName); + if (!structuralRenderInfo) { + let bridgeNode: BridgeNode = { + // Important Node properties. + name: structuralNodeName, + type: NodeType.BRIDGE, + // Unimportant Node properties. + isGroupNode: false, + cardinality: 1, + parentNode: null!, + stats: null!, + include: InclusionType.UNSPECIFIED, + // BridgeNode properties. + inbound: inbound, + inputData: {}, + outputData: {}, + suggestions: {}, + nodeAttributes: {}, + }; + structuralRenderInfo = new RenderNodeInfo(bridgeNode); + structuralRenderInfo.structural = true; + this.index[structuralNodeName] = structuralRenderInfo as any; + coreGraph.setNode(structuralNodeName, structuralRenderInfo); + bridgeContainerInfo.node.cardinality++; + coreGraph.setParent(structuralNodeName, bridgeContainerName); + } + // Create the structural Metaedge and insert it. + let structuralMetaedgeInfo = new RenderMetaedgeInfo(null!); + structuralMetaedgeInfo.structural = true; + structuralMetaedgeInfo.weight--; // Reduce weight for dagre layout. + inbound + ? coreGraph.setEdge(structuralNodeName, childName, structuralMetaedgeInfo) + : coreGraph.setEdge(childName, structuralNodeName, structuralMetaedgeInfo); + }); + }); + } + checkSubhierarchy(nodeName: string): boolean { + return nodeName in this.hasSubhierarchy; + } + removeSubhierarchy(nodeName: string): void { + console.log(this.index[nodeName]); + } + /** + * This method builds subhierarchies for function calls that are needed for + * rendering edges in the current subhierarchy being built. + * + * When building subhierarchies for a metagraph M, the subhierarchies of + * metanodes containing endpoint nodes for edges within metagraph M must + * already be built. Otherwise, bridge edges will be missing from the graph. + */ + // private buildSubhierarchiesForNeededFunctions(metagraph: graphlib.Graph) { + // _.each(metagraph.edges(), (edgeObj) => { + // let metaedge = metagraph.edge(edgeObj); + // let renderMetaedgeInfo = new RenderMetaedgeInfo(metaedge as any); + // _.forEach(renderMetaedgeInfo.metaedge.baseEdgeList, (baseEdge) => { + // const sourcePathList = baseEdge.v!.split(tf_graph.NAMESPACE_DELIM); + // for (let i = sourcePathList.length; i >= 0; i--) { + // const fromBeginningPathList = sourcePathList.slice(0, i); + // const node = this.hierarchy.node( + // fromBeginningPathList.join(tf_graph.NAMESPACE_DELIM) + // ); + // if (node) { + // if ( + // node.type === NodeType.OP && + // this.hierarchy.libraryFunctions[(node as OpNode).op] + // ) { + // for (let j = 1; j < fromBeginningPathList.length; j++) { + // // Expand all hierarchies including the parent. + // const currentNodeName = fromBeginningPathList + // .slice(0, j) + // .join(tf_graph.NAMESPACE_DELIM); + // if (!currentNodeName) { + // continue; + // } + // // Build the hierarchy for this current level. + // this.buildSubhierarchy(currentNodeName); + // } + // } + // // No need to analyze the other higher hierarchies. + // break; + // } + // } + // }); + // }); + // } +} +/** + * A class for rendering annotation object which contains label + * about the node embedded as annotation, type of annotation and the location + * of both the annotation's node and edge. + * + * Annotation objects include embedded constants, embedded summary, and + * edge shortcuts. + */ +export class Annotation { + node: Node; + renderNodeInfo: RenderNodeInfo; + renderMetaedgeInfo: RenderMetaedgeInfo; + annotationType: AnnotationType; + /** + * Center position of annotation relative to the host + * node's center x. + */ + dx: number; + /** + * Center position of annotation relative to the host + * node's center y. + */ + dy: number; + width: number; + height: number; + /** + * The names of nodes on either side of this edge. + */ + v: string; + w: string; + /** + * A flag whether it is an in-annotation (if true) or + * out-annotation (if false). + */ + isIn: boolean; + /** Label horizontal offset from the end of the node shape */ + labelOffset: number; + /** + * Array of points for edges from the annotation to its host + * node. Each point contains the point location, relative to + * the host node's center. + */ + points: { + dx: number; + dy: number; + }[]; + /** + * Creates a new Annotation. + * + * @param node The underlying node this annotation points to. + * @param renderNodeInfo The render information for the underlying node + * this annotation points to. This can be null if the annotation + * denotes an embedding (constant, summary), in which case we + * use the node property. + * @param renderMetaedgeInfo The render information for the edge associated + * with the annotation. + * @param type The type of the annotation. + * @param isIn True if it is an in-annotation. False if it is an + * out-annotation. + */ + constructor( + node: Node, + renderNodeInfo: RenderNodeInfo, + renderMetaedgeInfo: RenderMetaedgeInfo, + type: AnnotationType, + isIn: boolean, + ) { + this.node = node; + this.renderNodeInfo = renderNodeInfo; + this.renderMetaedgeInfo = renderMetaedgeInfo; + this.annotationType = type; + // Properties specified by layout + this.dx = 0; + this.dy = 0; + this.width = 0; + this.height = 0; + // Properties needed for generating an ID for the edge's path element if + // this annotation is associated with a metaedge. + if (renderMetaedgeInfo && renderMetaedgeInfo.metaedge) { + this.v = renderMetaedgeInfo.metaedge.v!; + this.w = renderMetaedgeInfo.metaedge.w!; + } + this.isIn = isIn; + this.points = []; + } +} +export enum AnnotationType { + SHORTCUT, + CONSTANT, + SUMMARY, + ELLIPSIS, +} +/** + * Manages a list of annotations. Two will be used for each + * RenderNodeInfo, one for in annotations and one for out annotations. + */ +export class AnnotationList { + /** + * List of visually drawable annotations, may include an ellipses annotation + * if the number added exceeds the number specified by maxAnnotations. + */ + list: Annotation[]; + /** + * Set of nodes which have been added as annotations to this list, so we can + * prevent duplicates. + */ + nodeNames: { + [nodeName: string]: boolean; + }; + constructor() { + this.list = []; + this.nodeNames = {}; + } + /** + * Append an annotation to the list, or a stand-in ellipsis annotation instead + * if this would make it too many. + */ + push(annotation: Annotation): void { + if (annotation.node.name in this.nodeNames) { + return; // Skip duplicate annotation. + } + this.nodeNames[annotation.node.name] = true; + if (this.list.length < PARAMS.maxAnnotations) { + this.list.push(annotation); + return; + } + let lastAnnotation = this.list[this.list.length - 1]; + if (lastAnnotation.annotationType === AnnotationType.ELLIPSIS) { + let ellipsisNode = lastAnnotation.node; + ellipsisNode.setNumMoreNodes(++ellipsisNode.numMoreNodes); + return; + } + let ellipsisNode = new tf_graph.EllipsisNodeImpl(1); + this.list.push( + new Annotation(ellipsisNode, new RenderNodeInfo(ellipsisNode), null!, AnnotationType.ELLIPSIS, annotation.isIn), + ); + } +} +/** + * Contains rendering information about a node in the hierarchical graph. + */ +export class RenderNodeInfo { + /** Reference to the original underlying Node from the hierarchical graph. */ + node: Node; + /** Whether the node is expanded or not. */ + expanded: boolean; + /** + * List of rendering information about in-annotations like constants and + * shortcuts to high-degree nodes. + */ + inAnnotations: AnnotationList; + /** + * List of rendering information about out-annotations (e.g. summary nodes) + */ + outAnnotations: AnnotationList; + // --- Params specified by layout --- // + /** Center x position */ + x: number; + /** Center y position */ + y: number; + /** + * Total width of the node's shape, including in- and out-annotations. This + * property is used by dagre to layout the graph. + */ + width: number; + /** + * Total height of the node's shape, including in- and out-annotations. This + * property is used by dagre to layout the graph. + */ + height: number; + /** + * Size of the main box of the node, excluding in- and out-annotations. This + * property is used to draw the rectangle/ellipse shape denoting the node. + */ + coreBox: { + width: number; + height: number; + }; + /** Width of the bounding box for all in-annotations. */ + inboxWidth: number; + /** Width of the bounding box for all out-annotations. */ + outboxWidth: number; + /** + * Whether the node should be excluded from the scene. + * This is only used when there are too many items in a series so we only + * want to include top N ones. + */ + // TODO: Now that series rendering is non-recursive, remove this and + // all its uses from the code base. + excluded: boolean; + // --- Params used in drawing the bridge paths --- // + /** + * All bridge nodes are meant to be invisible, but whereas most represent a + * relationship from the underlying graph hierarchy, some exist solely for + * layout reasons. Specifically, those bridge nodes which have only structural + * rendering metaedges. + */ + structural: boolean; + // --- Params for the size of the node box --- // + /** Label vertical offset from the center of node shape */ + labelOffset: number; + /** Rectangle radius (for making rounded rectangle) */ + radius: number; + // --- Params for expanded node --- // + /** Label height for expanded node. */ + labelHeight: number; + // Paddings between inner subscene and the border of the expanded node. + paddingTop: number; + paddingLeft: number; + paddingRight: number; + paddingBottom: number; + /** + * Whether a node is extracted as source-like (having high out-degree or + * matching predefined in-extract pattern.) + */ + isInExtract: boolean; + /** + * Whether a node is extracted as sink-like (having high in-degree or matching + * predefined out-extract pattern.) + */ + isOutExtract: boolean; + /** + * Whether a node represents a function template within the library, in which + * case it should be rendered in a special scene group. + */ + isLibraryFunction: boolean; + /** + * List of (color, proportion) tuples based on the proportion of devices of + * its children. If this node is an op node, this list will have only one + * color with proportion 1.0. + */ + deviceColors: Array<{ + color: string; + proportion: number; + }>; + /** + * List of (color, proportion) tuples based on the proportion of xlaClusters + * of its children. If this node is an op node, this list will have only one + * color with proportion 1.0. + */ + xlaClusterColors: Array<{ + color: string; + proportion: number; + }>; + /** + * List of (color, proportion) tuples based on the proportion of compatible + * nodes of its children. If this node is an op node, this list will have only + * one color with proportion 1.0. + */ + compatibilityColors: Array<{ + color: string; + proportion: number; + }>; + /** + * Color according to the memory usage of this node. + */ + memoryColor: string; + /** + * Color according to the compute time of this node. + */ + computeTimeColor: string; + /** + * Whether this node is faded out. Used when displaying stats. + */ + isFadedOut: boolean; + /** + * The name string used to label the node in the graph. + */ + displayName: string; + constructor(node: Node) { + this.node = node; + this.expanded = false; + this.inAnnotations = new AnnotationList(); + this.outAnnotations = new AnnotationList(); + // Params specified by layout + this.x = 0; + this.y = 0; + this.width = 0; + this.height = 0; + this.inboxWidth = 0; + this.outboxWidth = 0; + this.excluded = false; + // Params for bridge paths. + this.structural = false; + // Params for node box. + this.labelOffset = 0; + this.radius = 0; + // Params for expanded node + this.labelHeight = 0; + this.paddingTop = 0; + this.paddingLeft = 0; + this.paddingRight = 0; + this.paddingBottom = 0; + this.isInExtract = false; + this.isOutExtract = false; + this.coreBox = { width: 0, height: 0 }; + // By default, we don't fade nodes out. Default to false for safety. + this.isFadedOut = false; + // Only use the portion beyond the prefix as the display name. + if (node.name.startsWith('B___') && node.parentNode.name === tf_graph.ROOT_NAME) { + this.displayName = '标杆'; + } else { + const nameList = node.name.split('.'); + if (nameList.length > 3) { + const secondLastItem = nameList[nameList.length - 2]; + nameList.splice(nameList.length - 2, 1); + nameList.splice(2, 0, secondLastItem); + this.displayName = nameList.slice(1, nameList.length - 1).join('.'); + } else if (node.name.startsWith('B___') || node.name.startsWith('N___')) { + this.displayName = node.name.slice(4); + } else { + this.displayName = node.name; + } + } + + if (node.type === (NodeType.META || NodeType.API_LIST) && (node as Metanode).associatedFunction) { + // Function names are suffixed with a length-8 hexadecimal string + // followed by an optional number. We remove that suffix because + // the user did not generate that suffix. That suffix merely + // serves to differentiate between functions with different + // signatures but the same name otherwise. + // Furthermore, we remove the prefix that merely ascertains this + // node as a function definition. There is no reason for the user + // to see that in the graph, as the node would already be within + // the functions scene group. + const match = this.displayName.match(nodeDisplayNameRegex); + if (match) { + // The display name had been successfully extracted. This is the most + // common scenario. + this.displayName = match[1]; + } else if (_.startsWith(this.displayName, tf_graph.FUNCTION_LIBRARY_NODE_PREFIX)) { + // The string does not match the usual pattern for how functions are + // named. Just use the entire second portion of the string as the name + // if we can successfully remove the prefix. + this.displayName = this.displayName.substring(tf_graph.FUNCTION_LIBRARY_NODE_PREFIX.length); + } + } + } + isInCore(): boolean { + return !this.isInExtract && !this.isOutExtract && !this.isLibraryFunction; + } +} +/** + * Contains rendering information about a Metaedge from the underlying + * hierarchical graph. It may be from either a metagraph or a bridgegraph. + */ +export class RenderMetaedgeInfo { + /** + * Reference to the original underlying Metaedge from the hierarchical graph, + * if any. This will be null for the edges which connect OpNodes to their + * embeddings, for example. + */ + metaedge: Metaedge; + /** + * Reference to the adjoining RenderMetaedgeInfo from the parent's + * coreGraph. This is used during layout to determine the point at which this + * edge should touch the node's bounding box. This property will be null for + * edges which terminate at a node on both ends (all non-bridge edges). + */ + adjoiningMetaedge: RenderMetaedgeInfo; + /** + * Most of the time, a RenderMetaedgeInfo object represents a real + * edge between nodes in the underlying graph structure. But sometimes, an + * edge only exists for layout purposes. These structural edges are added + * during buildSubhierarchy() to force dagre.layout() to put bridge nodes + * at the ends of the flow. + * @see buildSubhierarchy() + */ + structural: boolean; + /** + * Weight of the edge, used by dagre when deciding how important an edge is. + * Edges with higher weight are made shorter and straighter. The default + * dagre uses is 1. + */ + weight: number; + /** + * X and Y coordinate pairs of the points in the path of the edge. + * @see tf_graph.node.subsceneAdjustPaths + */ + points: Point[]; + /** + * D3 selection of the group containing the path that displays this edge. + */ + edgeGroup: d3.Selection; + /** Id of the used as a start-marker for the edge path. */ + startMarkerId: string; + /** Id of the used as an end-marker for the edge path. */ + endMarkerId: string; + /** + * Whether this edge is faded out. Used for fading out unused edges when + * displaying run statistics. + */ + isFadedOut: boolean; + constructor(metaedge: Metaedge) { + this.metaedge = metaedge; + this.adjoiningMetaedge = null!; + this.structural = false; + this.weight = 1; + this.isFadedOut = false; + } +} +function addInAnnotation( + node: RenderNodeInfo, + predecessor: Node, + predecessorRenderInfo: RenderNodeInfo, + edge: RenderMetaedgeInfo, + type: AnnotationType, +): void { + let annotation = new Annotation(predecessor, predecessorRenderInfo, edge, type, true); + node.inAnnotations.push(annotation); +} +function addOutAnnotation( + node: RenderNodeInfo, + successor: Node, + successorRenderInfo: RenderNodeInfo, + edge: RenderMetaedgeInfo, + type: AnnotationType, +): void { + let annotation = new Annotation(successor, successorRenderInfo, edge, type, false); + node.outAnnotations.push(annotation); +} +function setGraphDepth(graph: graphlib.Graph, depth: number) { + _.each(graph.nodes(), (nodeName) => { + let child = graph.node(nodeName); + child.expanded = depth > 1; // set all child of depth 1 to collapsed + if (depth > 0) { + switch (child.node.type) { + case NodeType.META: + case NodeType.API_LIST: + case NodeType.SERIES: + setGroupNodeDepth(child, depth - 1); + break; + // Do nothing for leaf + } + } + }); +} +export class RenderGroupNodeInfo extends RenderNodeInfo { + override node: GroupNode; + /** + * The core graph is derived from the underlying node's metagraph, minus + * the extracted source-like and sink-like nodes. + */ + coreGraph: graphlib.Graph; + /** Size of the bounding box for a metanode's isolated in-extract children. */ + inExtractBox: { + width: number; + height: number; + }; + /** + * Size of the bounding box for a metanode's isolated out-extract children. + */ + outExtractBox: { + width: number; + height: number; + }; + /** Size of the bounding box for the function library. */ + libraryFunctionsBox: { + width: number; + height: number; + }; + /** Array of isolated in-extract nodes. */ + isolatedInExtract: RenderNodeInfo[]; + /** Array of isolated out-extract nodes. */ + isolatedOutExtract: RenderNodeInfo[]; + /** Array of nodes to show in the function library scene group. */ + libraryFunctionsExtract: RenderNodeInfo[]; + constructor(groupNode: GroupNode, graphOptions: tf_graph.LabeledGraphOptions) { + super(groupNode); + let metagraph = groupNode.metagraph; + let gl = metagraph.graph() as any; + this.coreGraph = createGraph(gl.name, GraphType.CORE, graphOptions); + this.inExtractBox = { width: 0, height: 0 }; + this.outExtractBox = { width: 0, height: 0 }; + this.libraryFunctionsBox = { width: 0, height: 0 }; + this.isolatedInExtract = []; + this.isolatedOutExtract = []; + this.libraryFunctionsExtract = []; + } +} +function setGroupNodeDepth(renderInfo: RenderGroupNodeInfo, depth: number): void { + if (renderInfo.coreGraph) { + setGraphDepth(renderInfo.coreGraph, depth); + } +} +/** + * Remove an edge from the graph and add annotations to both ends of the edge. + * + * @param The core graph. + * @param v Source name. + * @param w Sink name. + */ +function createShortcut(graph: graphlib.Graph, v: string, w: string) { + let src = graph.node(v) as any; + let sink = graph.node(w) as any; + let edge = graph.edge(v, w) as any; + // If either of the nodes is explicitly included in the main graph and + // both nodes are in the main graph then do not create the shortcut + // and instead keep the real edge. + if ( + (src.node.include === InclusionType.INCLUDE || sink.node.include === InclusionType.INCLUDE) && + src.node.include !== InclusionType.EXCLUDE && + sink.node.include !== InclusionType.EXCLUDE + ) { + return; + } + // Add each annotation. + addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT); + addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT); + // Remove the edge from the core graph. + graph.removeEdge(v, w); +} +/** + * Remove edges from a node, and set its isOutExtract property to true, + * and remove the node and move it to isolatedOutExtract. + * + * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its + * edges. Otherwise, only extract all in-edges. + */ +function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string, forceDetach?: boolean) { + let graph = renderNode.coreGraph; + let child = graph.node(n) as any; + child.isOutExtract = true; + _.each(graph.predecessors(n), (p, index) => { + createShortcut(graph, p, n); + }); + if (PARAMS.detachAllEdgesForHighDegree || forceDetach) { + _.each(graph.successors(n), (s, index) => { + createShortcut(graph, n, s); + }); + } + // Remove the node from the core graph if it no longer has neighbors. + if (graph.neighbors(n)?.length === 0) { + child.node.include = InclusionType.EXCLUDE; + renderNode.isolatedOutExtract.push(child); + graph.removeNode(n); + } +} +/** + * Remove edges from a node, set its isInExtract property to true, + * and remove the node and move it to isolatedInExtract. + * + * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its + * edges. Otherwise, only remove all out-edges. + */ +export function makeInExtract(renderNode: RenderGroupNodeInfo, n: string, forceDetach?: boolean) { + let graph = renderNode.coreGraph; + let child = graph.node(n) as any; + child.isInExtract = true; + _.each(graph.successors(n), (s, index) => { + createShortcut(graph, n, s); + }); + if (PARAMS.detachAllEdgesForHighDegree || forceDetach) { + _.each(graph.predecessors(n), (p, index) => { + createShortcut(graph, p, n); + }); + } + // Remove the node from the core graph if it no longer has neighbors. + if (graph.neighbors(n)?.length === 0) { + child.node.include = InclusionType.EXCLUDE; + renderNode.isolatedInExtract.push(child); + graph.removeNode(n); + } +} +/** + * Check whether the node's type is a member of the given list of types. + * + * @param node Node. + * @param types List of type to match. + */ +function hasTypeIn(node: Node, types: string[]): boolean { + if (node.type === NodeType.OP) { + for (let i = 0; i < types.length; i++) { + if ((node).op === types[i]) { + return true; + } + } + } else if (node.type === NodeType.META) { + let rootOpNode = (node).getRootOp(); + if (rootOpNode) { + for (let i = 0; i < types.length; i++) { + if (rootOpNode.op === types[i]) { + return true; + } + } + } + } + return false; +} +/** Move nodes that are specified to be excluded out of the core graph. */ +function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo) { + let graph = renderNode.coreGraph; + _.each(graph.nodes(), (n) => { + let renderInfo = graph.node(n); + if (renderInfo.node.include === InclusionType.EXCLUDE && !n.startsWith(tf_graph.FUNCTION_LIBRARY_NODE_PREFIX)) { + // Move the node if the node is excluded and not part of the library + // function scene group, which contains nodes that do not represent ops in + // the graph and should thus never have its nodes added to the core graph. + if (renderNode.coreGraph.outEdges(n)?.length! > renderNode.coreGraph.inEdges(n)?.length!) { + makeOutExtract(renderNode, n, true); + } else { + makeInExtract(renderNode, n, true); + } + } + }); +} +/** Remove edges from pre-defined out-extract patterns */ +function extractPredefinedSink(renderNode: RenderGroupNodeInfo) { + let graph = renderNode.coreGraph; + _.each(graph.nodes(), (n) => { + let renderInfo = graph.node(n); + if (renderInfo.node.include !== InclusionType.UNSPECIFIED) { + return; + } + if (hasTypeIn(renderInfo.node, PARAMS.outExtractTypes)) { + makeOutExtract(renderNode, n); + } + }); +} +/** Remove edges from pre-defined in-extract patterns */ +function extractPredefinedSource(renderNode) { + let graph = renderNode.coreGraph; + _.each(graph.nodes(), (n) => { + let renderInfo = graph.node(n); + if (renderInfo.node.include !== InclusionType.UNSPECIFIED) { + return; + } + if (hasTypeIn(renderInfo.node, PARAMS.inExtractTypes)) { + makeInExtract(renderNode, n); + } + }); +} +/** Extract nodes deemed to have either high in-degree or high out-degree. */ +function extractHighInOrOutDegree(renderNode: RenderGroupNodeInfo) { + let graph = renderNode.coreGraph; + // Create mappings from node to in and out degrees. Count the number of valid + // nodes along the way. + let nodeToInDegree = {}; + let nodeToOutDegree = {}; + let validNodeCount = 0; + _.each(graph.nodes(), (currentNode) => { + if (graph.node(currentNode).node.include !== InclusionType.UNSPECIFIED) { + // This node is not included in the first place. + return; + } + // Count the in and out degrees based on only regular edges, unless there + // are no regular edges, in which case use the number of control edges. + // This is done so that control edges don't affect if nodes are extracted + // from the core graph, unless the node is only used for control. + let inDegree = _.reduce( + graph.predecessors(currentNode), + (inDegree, pred) => { + let edgeObj = _.find(graph.edges(), (item) => item.v === pred && item.w === currentNode); + let metaedge = edgeObj ? graph.edge(edgeObj).metaedge : {}; + return inDegree + (metaedge.numRegularEdges ? 1 : 0); + }, + 0, + ); + if (inDegree === 0 && graph.predecessors(currentNode)?.length! > 0) { + inDegree = graph.predecessors(currentNode)?.length!; + } + let outDegree = _.reduce( + graph.successors(currentNode), + (outDegree, succ) => { + let edgeObj = _.find(graph.edges(), (item) => item.v === currentNode && item.w === succ); + let metaedge = edgeObj ? graph.edge(edgeObj).metaedge : {}; + return outDegree + (metaedge.numRegularEdges ? 1 : 0); + }, + 0, + ); + if (outDegree === 0 && graph.successors(currentNode)?.length! > 0) { + outDegree = graph.successors(currentNode)?.length; + } + // Store the in and out degrees of this node to avoid recomputing. + nodeToInDegree[currentNode] = inDegree; + nodeToOutDegree[currentNode] = outDegree; + validNodeCount++; + }); + if (validNodeCount < PARAMS.minNodeCountForExtraction) { + // This graph has few nodes. Do not extract any nodes. + return; + } + // We only extract if the node has a min in or out degree greater than this. + let minUpperBound = PARAMS.minDegreeForExtraction - 1; + // Mark for extraction nodes with in-degree > Q3 + (Q3 - Q1). + let q3Index = Math.round(validNodeCount * 0.75); + let q1Index = Math.round(validNodeCount * 0.25); + let sortedByInDegree = Object.keys(nodeToInDegree).sort((node0, node1) => { + return nodeToInDegree[node0] - nodeToInDegree[node1]; + }); + let inDegreeQ3 = nodeToInDegree[sortedByInDegree[q3Index]]; + let inDegreeQ1 = nodeToInDegree[sortedByInDegree[q1Index]]; + let inDegreeUpperBound = inDegreeQ3 + inDegreeQ3 - inDegreeQ1; + // Only extract if the upper bound is high enough. + inDegreeUpperBound = Math.max(inDegreeUpperBound, minUpperBound); + for (let i = validNodeCount - 1; nodeToInDegree[sortedByInDegree[i]] > inDegreeUpperBound; i--) { + // Extract a high in-degree node. + makeInExtract(renderNode, sortedByInDegree[i]); + } + // Mark for extraction nodes with out-degree > Q3 + (Q3 - Q1) * 4. + let sortedByOutDegree = Object.keys(nodeToOutDegree).sort((node0, node1) => { + return nodeToOutDegree[node0] - nodeToOutDegree[node1]; + }); + let outDegreeQ3 = nodeToOutDegree[sortedByOutDegree[q3Index]]; + let outDegreeQ1 = nodeToOutDegree[sortedByOutDegree[q1Index]]; + // The upper bound for extracting out-degree nodes is higher than that for + // extracting in-degree ones (Note the "* 4") because, in practice, some + // graphs look worse with a smaller out-degree bound. For instance, a smaller + // out-degree bound removes the convolution nodes from cifar 10 train's graph. + let outDegreeUpperBound = outDegreeQ3 + (outDegreeQ3 - outDegreeQ1) * 4; + // Only extract if the upper bound is high enough. + outDegreeUpperBound = Math.max(outDegreeUpperBound, minUpperBound); + for (let i = validNodeCount - 1; nodeToOutDegree[sortedByOutDegree[i]] > outDegreeUpperBound; i--) { + let node = graph.node(sortedByOutDegree[i]); + if (!node || node.isInExtract) { + // This node has already been extracted due to high in-degree. It might + // have been removed from the graph in general (during in-degree + // extraction) due to a lack of neighbors. Do not extract this node twice. + continue; + } + // Extract a high out-degree node that has not already been extracted. + makeOutExtract(renderNode, sortedByOutDegree[i]); + } +} +/** Remove control edges from nodes that have too many control edges */ +function removeControlEdges(renderNode: RenderGroupNodeInfo) { + let graph = renderNode.coreGraph; + // Collect control edges into a map by node name. + let map = <{ [nodeName: string]: any[] }>{}; + _.each(graph.edges(), (e) => { + if (!graph.edge(e).metaedge.numRegularEdges) { + (map[e.v] = map[e.v] || []).push(e); + (map[e.w] = map[e.w] || []).push(e); + } + }); + // For each node with too many control edges, turn them into annotations. + _.each(map, (edges, nodeName) => { + if (edges.length > PARAMS.maxControlDegree) { + _.each(edges, (e) => createShortcut(graph, e.v, e.w)); + } + }); +} +/** + * Given an integer, picks a hue that is far apart from other colors. + * The formula for picking color that avoid collision is: + * hue = (color range * golden ratio * index) % color range + */ +export function mapIndexToHue(id: number): number { + let GOLDEN_RATIO = 1.61803398875; + // Hue of 0 is reserved for the gray nodes. + let MIN_HUE = 1; + let MAX_HUE = 359; + let COLOR_RANGE = MAX_HUE - MIN_HUE; + return MIN_HUE + ((COLOR_RANGE * GOLDEN_RATIO * id) % COLOR_RANGE); +} +/** + * Remove edges and add to annotation instead. + * + * For root node, consider predefined types for source and sink. + * We do not extract predefined type from non-root so that Variables and the + * sgd node (op type = 'NoOp') do not get extract from inside own group. + * + * The order of extraction is important here as swapping the order can totally + * screw up the graph layout. + * + * @param {Render.Node} renderNode Node to manipulate. + * nodes from the main graph. If false, only exclude predefined nodes. + */ +/** + * Expands nodes in the graph until the desired node is visible. + * + * @param scene The scene polymer component. + * @param renderHierarchy The render hierarchy. + * @param tensorName The name of a tensor. + * @return A string that is the name of the node representing the given tensor. + * Note that the original tensor name might differ from this returned node + * name. Specifically, for instance, the tensor name usually ends with an + * output slot index (such as :0), while the node name lacks that suffix. + */ +export function expandUntilNodeIsShown(scene, renderHierarchy, tensorName: string) { + const splitTensorName = tensorName.split('/'); + // Graph names do not take into account the output slot. Strip it. + const lastNodeNameMatch = splitTensorName[splitTensorName.length - 1].match(/(.*):\w+/); + if (lastNodeNameMatch?.length === 2) { + splitTensorName[splitTensorName.length - 1] = lastNodeNameMatch?.[1]; + } + let nodeName = splitTensorName[0]; + let renderNode = renderHierarchy.getRenderNodeByName(nodeName); + for (let i = 1; i < splitTensorName.length; i++) { + // Op nodes are not expandable. + if (renderNode.node.type === tf_graph.NodeType.OP) { + break; + } + renderHierarchy.buildSubhierarchy(nodeName); + renderNode.expanded = true; + scene.setNodeExpanded(renderNode); + nodeName += '/' + splitTensorName[i]; + renderNode = renderHierarchy.getRenderNodeByName(nodeName); + } + return renderNode.node.name; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/scene.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/scene.ts new file mode 100644 index 0000000000000000000000000000000000000000..44dbc0d38e56bfa92d987f5be88d35d9348dfd35 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/scene.ts @@ -0,0 +1,578 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as d3 from 'd3'; +import * as PolymerDom from '../polymer/dom'; +import { Class as _Class, selectChild as _selectChild, SVG_NAMESPACE } from './common'; +import { NodeType, OpNode } from './graph'; +import * as layout from './layout'; +import * as render from './render'; + +export const selectChild = _selectChild; +export const Class = _Class; + +/** + * The dimensions of the minimap including padding and margin. + */ +const MINIMAP_BOX_WIDTH = 320; +const MINIMAP_BOX_HEIGHT = 150; +/** + * A health pill encapsulates an overview of tensor element values. The value + * field is a list of 12 numbers that shed light on the status of the tensor. + * Visualized in health pills are the 3rd through 8th (inclusive) numbers of + * health pill values. Those 6 numbers are counts of tensor elements that fall + * under -Inf, negative, 0, positive, +Inf, NaN (in that order). + * + * Please keep this interface consistent with HealthPillDatum within + * backend.ts. + */ +export interface HealthPill { + device_name: string; + node_name: string; + output_slot: number; + dtype: string; + shape: number[]; + value: number[]; + wall_time: number; + step: number; +} +interface HealthPillNumericStats { + min: number; + max: number; + mean: number; + stddev: number; +} +/** + * Encapsulates how to render a single entry in a health pill. Each entry + * corresponds to a category of tensor element values. + */ +export interface HealthPillEntry { + background_color: string; + label: string; +} +export let healthPillEntries: HealthPillEntry[] = [ + { + background_color: '#CC2F2C', + label: 'NaN', + }, + { + background_color: '#FF8D00', + label: '-∞', + }, + { + background_color: '#EAEAEA', + label: '-', + }, + { + background_color: '#A5A5A5', + label: '0', + }, + { + background_color: '#262626', + label: '+', + }, + { + background_color: '#003ED4', + label: '+∞', + }, +]; +/** + * Helper method for fitting the graph in the svg view. + * + * @param svg The main svg. + * @param zoomG The svg group used for panning and zooming. + * @param d3zoom The zoom behavior. + * @param callback Called when the fitting is done. + */ +export function fit(svg, zoomG, d3zoom, callback) { + let svgRect = svg.getBoundingClientRect(); + let sceneSize: DOMRect | null = null; + try { + sceneSize = zoomG.getBBox(); + if (sceneSize?.width! === 0) { + // There is no scene anymore. We have been detached from the dom. + return; + } + } catch (e) { + // Firefox produced NS_ERROR_FAILURE if we have been + // detached from the dom. + return; + } + let scale = 0.9 * Math.min(svgRect.width / (sceneSize?.width ?? 1), svgRect.height / (sceneSize?.height ?? 1), 2); + let params = layout.PARAMS.graph; + const transform = d3.zoomIdentity.scale(scale).translate(params.padding.paddingLeft, params.padding.paddingTop); + d3.select(svg) + .transition() + .duration(500) + .call(d3zoom.transform, transform) + .on('end.fitted', () => { + // Remove the listener for the zoomend event, + // so we don't get called at the end of regular zoom events, + // just those that fit the graph to screen. + d3zoom.on('end.fitted', null); + callback(); + }); +} +/** + * Helper method for panning the graph to center on the provided node, + * if the node is currently off-screen. + * + * @param nodeName The node to center the graph on + * @param svg The root SVG element for the graph + * @param zoomG The svg group used for panning and zooming. + * @param d3zoom The zoom behavior. + * @return True if the graph had to be panned to display the + * provided node. + */ +export function panToNode(nodeName: String, svg, zoomG, d3zoom): boolean { + const node = d3.select(svg).select(`[data-name="${nodeName}"]`).node(); + if (!node) { + console.warn(`panToNode() failed for node name "${nodeName}"`); + return false; + } + // Check if the selected node is off-screen in either + // X or Y dimension in either direction. + let nodeBox = node.getBBox(); + let nodeCtm = node.getScreenCTM(); + let pointTL = svg.createSVGPoint(); + let pointBR = svg.createSVGPoint(); + pointTL.x = nodeBox.x; + pointTL.y = nodeBox.y; + pointBR.x = nodeBox.x + nodeBox.width; + pointBR.y = nodeBox.y + nodeBox.height; + pointTL = pointTL.matrixTransform(nodeCtm); + pointBR = pointBR.matrixTransform(nodeCtm); + let isOutsideOfBounds = (start, end, lowerBound, upperBound) => { + // Return if even a part of the interval is out of bounds. + return !(start > lowerBound && end < upperBound); + }; + let svgRect = svg.getBoundingClientRect(); + // Subtract to make sure that the node is not hidden behind the minimap. + const horizontalBound = svgRect.left + svgRect.width - MINIMAP_BOX_WIDTH; + const verticalBound = svgRect.top + svgRect.height - MINIMAP_BOX_HEIGHT; + if ( + isOutsideOfBounds(pointTL.x, pointBR.x, svgRect.left, horizontalBound) || + isOutsideOfBounds(pointTL.y, pointBR.y, svgRect.top, verticalBound) + ) { + // Determine the amount to translate the graph in both X and Y dimensions in + // order to center the selected node. This takes into account the position + // of the node, the size of the svg scene, the amount the scene has been + // scaled by through zooming, and any previous transforms already performed + // by this logic. + let centerX = (pointTL.x + pointBR.x) / 2; + let centerY = (pointTL.y + pointBR.y) / 2; + let dx = svgRect.left + svgRect.width / 2 - centerX; + let dy = svgRect.top + svgRect.height / 2 - centerY; + // We translate by this amount. We divide the X and Y translations by the + // scale to undo how translateBy scales the translations (in d3 v4). + const svgTransform = d3.zoomTransform(svg); + d3.select(svg) + .transition() + .duration(500) + .call(d3zoom.translateBy, dx / svgTransform.k, dy / svgTransform.k); + return true; + } + return false; +} +/** + * Given a scene's svg group, set g.in-extract, g.coreGraph, g.out-extract svg + * groups' position relative to the scene. + * + * @param sceneGroup + * @param renderNode render node of a metanode or series node. + */ +export function position(sceneGroup, renderNode: render.RenderGroupNodeInfo) { + // Translate scenes down by the label height so that when showing graphs in + // expanded metanodes, the graphs are below the labels. Do not shift them + // down for series nodes as series nodes don't have labels inside of their + // bounding boxes. + let yTranslate = renderNode.node.type === NodeType.SERIES ? 0 : layout.PARAMS.subscene.meta.labelHeight; + // core + translate(selectChild(sceneGroup, 'g', Class.Scene.CORE), 0, yTranslate); + // in-extract + let hasInExtract = renderNode.isolatedInExtract.length > 0; + let hasOutExtract = renderNode.isolatedOutExtract.length > 0; + let hasLibraryFunctions = renderNode.libraryFunctionsExtract.length > 0; + let offset = layout.PARAMS.subscene.meta.extractXOffset; + let auxWidth = 0; + if (hasInExtract) { + auxWidth += renderNode.outExtractBox.width; + } + if (hasOutExtract) { + auxWidth += renderNode.outExtractBox.width; + } + if (hasInExtract) { + let inExtractX = renderNode.coreBox.width; + if (auxWidth < layout.MIN_AUX_WIDTH) { + inExtractX = inExtractX - layout.MIN_AUX_WIDTH + renderNode.inExtractBox.width / 2; + } else { + inExtractX = + inExtractX - renderNode.inExtractBox.width / 2 - renderNode.outExtractBox.width - (hasOutExtract ? offset : 0); + } + inExtractX = inExtractX - renderNode.libraryFunctionsBox.width - (hasLibraryFunctions ? offset : 0); + translate(selectChild(sceneGroup, 'g', Class.Scene.INEXTRACT), inExtractX, yTranslate); + } + // out-extract + if (hasOutExtract) { + let outExtractX = renderNode.coreBox.width; + if (auxWidth < layout.MIN_AUX_WIDTH) { + outExtractX = outExtractX - layout.MIN_AUX_WIDTH + renderNode.outExtractBox.width / 2; + } else { + outExtractX -= renderNode.outExtractBox.width / 2; + } + outExtractX = outExtractX - renderNode.libraryFunctionsBox.width - (hasLibraryFunctions ? offset : 0); + translate(selectChild(sceneGroup, 'g', Class.Scene.OUTEXTRACT), outExtractX, yTranslate); + } + if (hasLibraryFunctions) { + let libraryFunctionsExtractX = renderNode.coreBox.width - renderNode.libraryFunctionsBox.width / 2; + translate(selectChild(sceneGroup, 'g', Class.Scene.FUNCTION_LIBRARY), libraryFunctionsExtractX, yTranslate); + } +} +/** Adds a click listener to a group that fires a graph-select event */ +export function addGraphClickListener(graphGroup, sceneElement) { + d3.select(graphGroup).on('click', () => { + sceneElement.fire('graph-select'); + }); +} +/** Helper for adding transform: translate(x0, y0) */ +export function translate(selection, x0: number, y0: number) { + // If it is already placed on the screen, make it a transition. + if (selection.attr('transform') != null) { + selection = selection.transition('position'); + } + selection.attr('transform', 'translate(' + x0 + ',' + y0 + ')'); +} +/** + * Helper for setting position of a svg rect + * @param rect A d3 selection of rect(s) to set position of. + * @param cx Center x. + * @param cy Center x. + * @param width Width to set. + * @param height Height to set. + */ +export function positionRect(rect, cx: number, cy: number, width: number, height: number) { + rect + .transition() + .attr('x', cx - width / 2) + .attr('y', cy - height / 2) + .attr('width', width) + .attr('height', height); +} +/** + * Positions a triangle and sizes it. + * @param polygon polygon to set position of. + * @param cx Center x. + * @param cy Center y. + * @param width Width of bounding box for triangle. + * @param height Height of bounding box for triangle. + */ +export function positionTriangle(polygon, cx, cy, width, height) { + const halfHeight = height / 2; + const halfWidth = width / 2; + const points = [ + [cx, cy - halfHeight], + [cx + halfWidth, cy + halfHeight], + [cx - halfWidth, cy + halfHeight], + ]; + polygon.transition().attr('points', points.map((point) => point.join(',')).join(' ')); +} +/** + * Helper for setting position of a svg expand/collapse button + * @param button container group + * @param renderNode the render node of the group node to position + * the button on. + */ +export function positionButton(button, renderNode: render.RenderNodeInfo) { + let cx = layout.computeCXPositionOfNodeShape(renderNode); + // Position the button in the top-right corner of the group node, + // with space given the draw the button inside of the corner. + let width = renderNode.expanded ? renderNode.width : renderNode.coreBox.width; + let height = renderNode.expanded ? renderNode.height : renderNode.coreBox.height; + let x = cx + width / 2 - 6; + let y = renderNode.y - height / 2 + 6; + // For unexpanded series nodes, the button has special placement due + // to the unique visuals of this group node. + if (renderNode.node.type === NodeType.SERIES && !renderNode.expanded) { + x += 10; + y -= 2; + } + let translateStr = 'translate(' + x + ',' + y + ')'; + button.selectAll('path').transition().attr('transform', translateStr); + button.select('circle').transition().attr({ cx: x, cy: y, r: layout.PARAMS.nodeSize.meta.expandButtonRadius }); +} +/** + * Helper for setting position of a svg ellipse + * @param ellipse ellipse to set position of. + * @param cx Center x. + * @param cy Center x. + * @param width Width to set. + * @param height Height to set. + */ +export function positionEllipse(ellipse, cx: number, cy: number, width: number, height: number) { + ellipse + .transition() + .attr('cx', cx) + .attr('cy', cy) + .attr('rx', width / 2) + .attr('ry', height / 2); +} +/** + * @param {number} stat A stat for a health pill (such as mean or variance). + * @param {boolean} shouldRoundOnesDigit Whether to round this number to the + * ones digit. Useful for say int, uint, and bool output types. + * @return {string} A human-friendly string representation of that stat. + */ +export function humanizeHealthPillStat(stat, shouldRoundOnesDigit) { + if (shouldRoundOnesDigit) { + return stat.toFixed(0); + } + if (Math.abs(stat) >= 1) { + return stat.toFixed(1); + } + return stat.toExponential(1); +} +/** + * Get text content describing a health pill. + */ +function _getHealthPillTextContent( + healthPill: HealthPill, + totalCount: number, + elementsBreakdown: number[], + numericStats: HealthPillNumericStats, +) { + let text = 'Device: ' + healthPill.device_name + '\n'; + text += 'dtype: ' + healthPill.dtype + '\n'; + let shapeStr = '(scalar)'; + if (healthPill.shape.length > 0) { + shapeStr = '(' + healthPill.shape.join(',') + ')'; + } + text += '\nshape: ' + shapeStr + '\n\n'; + text += '#(elements): ' + totalCount + '\n'; + const breakdownItems: string[] = []; + for (let i = 0; i < elementsBreakdown.length; i++) { + if (elementsBreakdown[i] > 0) { + breakdownItems.push('#(' + healthPillEntries[i].label + '): ' + elementsBreakdown[i]); + } + } + text += breakdownItems.join(', ') + '\n\n'; + // In some cases (e.g., size-0 tensors; all elements are nan or inf) the + // min/max and mean/stddev stats are meaningless. + if (numericStats.max >= numericStats.min) { + text += 'min: ' + numericStats.min + ', max: ' + numericStats.max + '\n'; + text += 'mean: ' + numericStats.mean + ', stddev: ' + numericStats.stddev; + } + return text; +} +/** + * Renders a health pill for an op atop a node. + * nodeGroupElement: The SVG element in which to render. + * healthPill: A list of backend.HealthPill objects. + * nodeInfo: Info on the associated node. + * healthPillId: A unique numeric ID assigned to this health pill. + * healthPillWidth: Optional width of the health pill. + * healthPillHeight: Optional height of the health pill. + * healthPillYOffset: Optional y-offset of the health pill (that is, the + * color-coded region). + * textOffset: Optional value for the x-offset of the top text label + * relative to the left edge of the health pill. If not provided, will + * default to `healthPillWidth / 2`. + */ +export function addHealthPill( + nodeGroupElement: SVGElement, + healthPill: HealthPill, + nodeInfo: render.RenderNodeInfo, + healthPillId: number, + healthPillWidth = 60, + healthPillHeight = 10, + healthPillYOffset = 0, + textXOffset?: number, +) { + // Check if text already exists at location. + d3.select(nodeGroupElement.parentNode as any) + .selectAll('.health-pill') + .remove(); + if (!healthPill) { + return; + } + const lastHealthPillData = healthPill.value; + // For now, we only visualize the 6 values that summarize counts of tensor + // elements of various categories: -Inf, negative, 0, positive, Inf, and NaN. + const lastHealthPillElementsBreakdown = lastHealthPillData.slice(2, 8); + const nanCount = lastHealthPillElementsBreakdown[0]; + const negInfCount = lastHealthPillElementsBreakdown[1]; + const posInfCount = lastHealthPillElementsBreakdown[5]; + let totalCount = lastHealthPillData[1]; + const numericStats: HealthPillNumericStats = { + min: lastHealthPillData[8], + max: lastHealthPillData[9], + mean: lastHealthPillData[10], + stddev: Math.sqrt(lastHealthPillData[11]), + }; + if (healthPillWidth == null) { + healthPillWidth = 60; + } + if (healthPillHeight == null) { + healthPillHeight = 10; + } + if (healthPillYOffset == null) { + healthPillYOffset = 0; + } + if (nodeInfo != null && nodeInfo.node.type === NodeType.OP) { + // Use a smaller health pill for op nodes (rendered as smaller ellipses). + healthPillWidth /= 2; + healthPillHeight /= 2; + } + let healthPillGroup = document.createElementNS(SVG_NAMESPACE, 'g'); + healthPillGroup.classList.add('health-pill'); + // Define the gradient for the health pill. + let healthPillDefs = document.createElementNS(SVG_NAMESPACE, 'defs'); + healthPillGroup.appendChild(healthPillDefs); + let healthPillGradient = document.createElementNS(SVG_NAMESPACE, 'linearGradient'); + // Every element in a web page must have a unique ID. + const healthPillGradientId = 'health-pill-gradient-' + healthPillId; + healthPillGradient.setAttribute('id', healthPillGradientId); + let cumulativeCount = 0; + let previousOffset = '0%'; + for (let i = 0; i < lastHealthPillElementsBreakdown.length; i++) { + if (!lastHealthPillElementsBreakdown[i]) { + // Exclude empty categories. + continue; + } + cumulativeCount += lastHealthPillElementsBreakdown[i]; + // Create a color interval using 2 stop elements. + let stopElement0 = document.createElementNS(SVG_NAMESPACE, 'stop'); + stopElement0.setAttribute('offset', previousOffset); + stopElement0.setAttribute('stop-color', healthPillEntries[i].background_color); + healthPillGradient.appendChild(stopElement0); + let stopElement1 = document.createElementNS(SVG_NAMESPACE, 'stop'); + let percent = (cumulativeCount * 100) / totalCount + '%'; + stopElement1.setAttribute('offset', percent); + stopElement1.setAttribute('stop-color', healthPillEntries[i].background_color); + healthPillGradient.appendChild(stopElement1); + previousOffset = percent; + } + healthPillDefs.appendChild(healthPillGradient); + // Create the rectangle for the health pill. + let rect = document.createElementNS(SVG_NAMESPACE, 'rect'); + rect.setAttribute('fill', 'url(#' + healthPillGradientId + ')'); + rect.setAttribute('width', String(healthPillWidth)); + rect.setAttribute('height', String(healthPillHeight)); + rect.setAttribute('y', String(healthPillYOffset)); + healthPillGroup.appendChild(rect); + // Show a title with specific counts on hover. + let titleSvg = document.createElementNS(SVG_NAMESPACE, 'title'); + titleSvg.textContent = _getHealthPillTextContent( + healthPill, + totalCount, + lastHealthPillElementsBreakdown, + numericStats, + ); + healthPillGroup.appendChild(titleSvg); + // Center this health pill just right above the node for the op. + let shouldRoundOnesDigit = false; + if (nodeInfo != null) { + let healthPillX = nodeInfo.x - healthPillWidth / 2; + let healthPillY = nodeInfo.y - healthPillHeight - nodeInfo.height / 2 - 2; + if (nodeInfo.labelOffset < 0) { + // The label is positioned above the node. Do not occlude the label. + healthPillY += nodeInfo.labelOffset; + } + healthPillGroup.setAttribute('transform', 'translate(' + healthPillX + ', ' + healthPillY + ')'); + if ( + lastHealthPillElementsBreakdown[2] || + lastHealthPillElementsBreakdown[3] || + lastHealthPillElementsBreakdown[4] + ) { + // At least 1 "non-Inf and non-NaN" value exists (a -, 0, or + value). Show + // stats on tensor values. + // Determine if we should display the output range as integers. + let node = nodeInfo.node as OpNode; + let attributes = node.attr; + if (attributes && attributes.length) { + // Find the attribute for output type if there is one. + for (let i = 0; i < attributes.length; i++) { + if (attributes[i].key === 'T') { + // Note whether the output type is an integer. + let outputType = attributes[i].value['type']; + shouldRoundOnesDigit = outputType && /^DT_(BOOL|INT|UINT)/.test(outputType); + break; + } + } + } + } + } + let statsSvg = document.createElementNS(SVG_NAMESPACE, 'text'); + if (Number.isFinite(numericStats.min) && Number.isFinite(numericStats.max)) { + const minString = humanizeHealthPillStat(numericStats.min, shouldRoundOnesDigit); + const maxString = humanizeHealthPillStat(numericStats.max, shouldRoundOnesDigit); + if (totalCount > 1) { + statsSvg.textContent = minString + ' ~ ' + maxString; + } else { + statsSvg.textContent = minString; + } + if (nanCount > 0 || negInfCount > 0 || posInfCount > 0) { + statsSvg.textContent += ' ('; + const badValueStrings: string[] = []; + if (nanCount > 0) { + badValueStrings.push(`NaN×${nanCount}`); + } + if (negInfCount > 0) { + badValueStrings.push(`-∞×${negInfCount}`); + } + if (posInfCount > 0) { + badValueStrings.push(`+∞×${posInfCount}`); + } + statsSvg.textContent += badValueStrings.join('; ') + ')'; + } + } else { + statsSvg.textContent = '(No finite elements)'; + } + statsSvg.classList.add('health-pill-stats'); + if (textXOffset == null) { + textXOffset = healthPillWidth / 2; + } + statsSvg.setAttribute('x', String(textXOffset)); + statsSvg.setAttribute('y', String(healthPillYOffset - 2)); + healthPillGroup.appendChild(statsSvg); + (PolymerDom.dom(nodeGroupElement.parentNode) as any).appendChild(healthPillGroup); +} +/** + * Adds health pills (which visualize tensor summaries) to a graph group. + * @param svgRoot The root SVG element of the graph to add heath pills to. + * @param nodeNamesToHealthPills An object mapping node name to health pill. + * @param colors A list of colors to use. + */ +export function addHealthPills( + svgRoot: SVGElement, + nodeNamesToHealthPills: { + [key: string]: HealthPill[]; + }, + healthPillStepIndex: number, +) { + if (!nodeNamesToHealthPills) { + // No health pill information available. + return; + } + // We generate a unique ID for each health pill because the ID of each element + // in a web page must be unique, and each health pill generates a gradient + // that its code later refers to. + let healthPillId = 1; + let svgRootSelection = d3.select(svgRoot); + svgRootSelection.selectAll('g.nodeshape').each(function (nodeInfo: render.RenderNodeInfo) { + // Only show health pill data for this node if it is available. + const healthPills = nodeNamesToHealthPills[nodeInfo.node.name]; + const healthPill = healthPills ? healthPills[healthPillStepIndex] : null; + addHealthPill(this as SVGElement, healthPill!, nodeInfo, healthPillId++); + }); +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/template.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/template.ts new file mode 100644 index 0000000000000000000000000000000000000000..7474f88a64a06a3c2acd55d850f2fa5ba28a0f1c --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/template.ts @@ -0,0 +1,309 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {graphlib} from 'dagre'; +import * as _ from 'lodash'; +import { + GroupNode, + hasSimilarDegreeSequence, + Metanode, + NodeType, + OpNode, + SeriesNode, +} from './graph'; +import {Hierarchy} from './hierarchy'; + +export function detect( + h, + verifyTemplate +): { + [templateId: string]: string[]; +} { + // In any particular subgraph, there are either + // - leaf nodes (which do not have subgraph) + // - metanode nodes - some of them have only one member (singular metanode) + // and some have multiple members (non-singular metanode) + // First, generate a nearest neighbor hash of metanode nodes. + let nnGroups = clusterSimilarSubgraphs(h); + // For each metanode, compare its subgraph (starting from shallower groups) + // and assign template id. + let templates = groupTemplateAndAssignId(nnGroups, verifyTemplate); + // Sort the templates by minimum level in the graph at which they appear, + // as this leads to optimal setting of the colors of each template for + // maximum differentiation. + return < + { + [templateId: string]: string[]; + } + >Object.keys(templates) + .sort((key) => templates[key].level) + .reduce((obj, key) => { + obj[key] = templates[key]; + return obj; + }, {}); +} +/** + * @return Unique string for a metanode based on depth, |V|, |E| and + * op type histogram. + */ +function getSignature(metanode) { + // depth= |V|= |E|= + let props = _.map( + { + depth: metanode.depth, + '|V|': metanode.metagraph.nodes().length, + '|E|': metanode.metagraph.edges().length, + }, + function (v, k) { + return k + '=' + v; + } + ).join(' '); + // optype1=count1,optype2=count2 + let ops = _.map(metanode.opHistogram, function (count, op) { + return op + '=' + count; + }).join(','); + return props + ' [ops] ' + ops; +} +/** + * Generate a nearest neighbor hash of metanodes + * based on depth, |V|, |E|, and opHistogram of their subgraph + * (excluding leaf nodes and singular metanodes). + * @param graph The graph + * @return Array of pairs of [signature, + * Object with min level of the template and an Array of tf.graph.Group] + * sort by ascending order of minimum depth at which metanode appears. + */ +function clusterSimilarSubgraphs(h: Hierarchy) { + /** a dict from metanode.signature() => Array of tf.graph.Groups */ + const map = h.getNodeMap(); + let hashDict = Object.keys(map).reduce((reduced: Object, name: string) => { + const node: OpNode | GroupNode = map[name]; + if (node.type !== NodeType.META) { + return reduced; + } + let levelOfMetaNode = name.split('/').length - 1; + let signature = getSignature(node); + let templateInfo = reduced[signature] || { + nodes: [], + level: levelOfMetaNode, + }; + reduced[signature] = templateInfo; + templateInfo.nodes.push(node); + if (templateInfo.level > levelOfMetaNode) { + templateInfo.level = levelOfMetaNode; + } + return reduced; + }, {}); + return Object.keys(hashDict) + .map((key) => [key, hashDict[key]]) + .filter(([_, subGraph]) => { + const {nodes} = subGraph; + if (nodes.length > 1) { + // There is more than 1 node with this template. It is worth assigning + // a unique color to this template. + return true; + } + // If there is only 1 node with this template, only make a template for + // it if it represents a function. In that case, the graph explorer may + // add more nodes with the template later. + const node = nodes[0]; + return ( + node.type === NodeType.META && (node as Metanode).associatedFunction + ); + }) + .sort(([_, subGraph]) => { + // sort by depth + // (all members in the same nnGroup has equal depth) + return subGraph.nodes[0].depth; + }); +} +function groupTemplateAndAssignId(nnGroups, verifyTemplate) { + // For each metanode, compare its subgraph (starting from shallower groups) + // and assign template id. + let result: { + [templateId: string]: { + level: number; + nodes: string[]; + }; + } = {}; + return _.reduce( + nnGroups, + function (templates, nnGroupPair) { + let signature = nnGroupPair[0], + nnGroup = nnGroupPair[1].nodes, + clusters: {metanode: Metanode; members: string[]}[] = []; + nnGroup.forEach(function (metanode) { + // check with each existing cluster + for (let i = 0; i < clusters.length; i++) { + let similar = + !verifyTemplate || + isSimilarSubgraph( + clusters[i].metanode.metagraph, + metanode.metagraph + ); + // if similar, just add this metanode to the cluster + if (similar) { + // get template from the first one + metanode.templateId = clusters[i].metanode.templateId; + clusters[i].members.push(metanode.name); + return; + } + } + // otherwise create a new cluster with id 'signature [count] ' + metanode.templateId = signature + '[' + clusters.length + ']'; + clusters.push({ + metanode: metanode, + members: [metanode.name], + }); + }); + clusters.forEach(function (c) { + templates[c.metanode.templateId] = { + level: nnGroupPair[1].level, + nodes: c.members, + }; + }); + return templates; + }, + result + ); +} +function sortNodes(names: string[], graph: graphlib.Graph, prefix: string) { + return _.sortBy(names, [ + (name) => (graph.node(name) as unknown as OpNode).op, + (name) => (graph.node(name) as unknown as Metanode).templateId, + (name) => graph.neighbors(name)?.length, + (name) => graph.predecessors(name)?.length, + (name) => graph.successors(name)?.length, + (name) => name.substr(prefix.length), + ]); +} +function isSimilarSubgraph(g1: graphlib.Graph, g2: graphlib.Graph) { + if (!hasSimilarDegreeSequence(g1, g2)) { + return false; + } + // if we want to skip, just return true here. + // return true; + // Verify sequence by running DFS + let g1prefix = (g1.graph() as any).name; + let g2prefix = (g2.graph() as any).name; + let visited1 = {}; + let visited2 = {}; + let stack: {n1: string; n2: string}[] = []; + /** + * push sources or successors into the stack + * if the visiting pattern has been similar. + */ + function stackPushIfNotDifferent(n1, n2) { + let sub1 = n1.substr(g1prefix.length), + sub2 = n2.substr(g2prefix.length); + /* tslint:disable */ + if (visited1[sub1] ^ visited2[sub2]) { + console.warn( + 'different visit pattern', + '[' + g1prefix + ']', + sub1, + '[' + g2prefix + ']', + sub2 + ); + return true; + } + /* tslint:enable */ + if (!visited1[sub1]) { + // implied && !visited2[sub2] + visited1[sub1] = visited2[sub2] = true; + stack.push({n1: n1, n2: n2}); + } + return false; + } + // check if have same # of sources then sort and push + let sources1 = g1.sources(); + let sources2 = g2.sources(); + if (sources1.length !== sources2.length) { + /* tslint:disable */ + console.log('different source length'); + /* tslint:enable */ + return false; + } + sources1 = sortNodes(sources1 as any, g1, g1prefix); + sources2 = sortNodes(sources2 as any, g2, g2prefix); + for (let i = 0; i < sources1.length; i++) { + let different = stackPushIfNotDifferent(sources1[i], sources2[i]); + if (different) { + return false; + } + } + while (stack.length > 0) { + let cur = stack.pop(); + // check node + let similar = isSimilarNode( + g1.node(cur?.n1!) as any, + g2.node(cur?.n2!) as any + ); + if (!similar) { + return false; + } + // check if have same # of successors then sort and push + let succ1 = g1.successors(cur?.n1!), + succ2 = g2.successors(cur?.n2!); + if (succ1?.length! !== succ2?.length!) { + /* tslint:disable */ + console.log('# of successors mismatch', succ1, succ2); + /* tslint:enable */ + return false; + } + succ1 = sortNodes(succ1 as any, g1, g1prefix); + succ2 = sortNodes(succ2 as any, g2, g2prefix); + for (let j = 0; j < succ1?.length!; j++) { + let different = stackPushIfNotDifferent(succ1?.[j]!, succ2?.[j]!); + if (different) { + return false; + } + } + } + return true; +} +/** + * Returns if two nodes have identical structure. + */ +function isSimilarNode( + n1: OpNode | Metanode | SeriesNode, + n2: OpNode | Metanode | SeriesNode +): boolean { + if (n1.type === NodeType.META) { + // compare metanode + let metanode1 = n1; + let metanode2 = n2; + return ( + !!metanode1.templateId && + !!metanode2.templateId && + metanode1.templateId === metanode2.templateId + ); + } else if (n1.type === NodeType.OP && n2.type === NodeType.OP) { + // compare leaf node + return (n1).op === (n2).op; + } else if (n1.type === NodeType.SERIES && n2.type === NodeType.SERIES) { + // compare series node sizes and operations + // (only need to check one op as all op nodes are identical in series) + let sn1 = n1; + let sn2 = n2; + let seriesnode1Count = sn1.metagraph.nodeCount(); + return ( + seriesnode1Count === sn2.metagraph.nodeCount() && + (seriesnode1Count === 0 || + (sn1.metagraph.node(sn1.metagraph.nodes()[0])).op === + (sn2.metagraph.node(sn2.metagraph.nodes()[0])).op) + ); + } + return false; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/tf-graph-icon.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/tf-graph-icon.ts new file mode 100644 index 0000000000000000000000000000000000000000..9e598dacd563d3c18d2f7a20d736ec33a2f9bbbc --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/tf-graph-icon.ts @@ -0,0 +1,245 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { computed, customElement, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import { DarkModeMixin } from '../polymer/dark_mode_mixin'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import '../tf_dashboard_common/tensorboard-color'; +import { MetanodeColors, OpNodeColors, SeriesNodeColors } from './render'; + +export enum GraphIconType { + CONST = 'CONST', + META = 'META', + OP = 'OP', + SERIES = 'SERIES', + SUMMARY = 'SUMMARY', + API_LIST = 'API_LIST', +} +@customElement('tf-graph-icon') +class TfGraphIcon extends LegacyElementMixin(DarkModeMixin(PolymerElement)) { + static readonly template = html` + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + `; + @property({ type: String }) + type: string; + @property({ + type: Boolean, + }) + vertical: boolean = false; + @property({ + type: String, + }) + fillOverride: string | null = null; + @property({ + type: String, + }) + strokeOverride: string | null = null; + @property({ + type: Number, + }) + height: number = 20; + @property({ + type: Boolean, + }) + faded: boolean = false; + getSvgDefinableElement(): HTMLElement { + return this.$.svgDefs as HTMLElement; + } + @computed('type', 'fillOverride') + get _fill(): string { + var type = this.type; + var fillOverride = this.fillOverride; + if (fillOverride != null) return fillOverride; + switch (type) { + case GraphIconType.META: + return MetanodeColors.DEFAULT_FILL; + case GraphIconType.SERIES: + return SeriesNodeColors.DEFAULT_FILL; + default: + return OpNodeColors.DEFAULT_FILL; + } + } + @computed('type', 'strokeOverride') + get _stroke(): string { + var type = this.type; + var strokeOverride = this.strokeOverride; + if (strokeOverride != null) return strokeOverride; + switch (type) { + case GraphIconType.META: + return MetanodeColors.DEFAULT_STROKE; + case GraphIconType.SERIES: + return SeriesNodeColors.DEFAULT_STROKE; + default: + return OpNodeColors.DEFAULT_STROKE; + } + } + /** + * Test whether the specified node's type, or the literal type string, + * match a particular other type. + */ + _isType(type: GraphIconType, targetType: GraphIconType): boolean { + return type === targetType; + } + _fadedClass(faded: boolean, shape: string) { + return faded ? 'faded-' + shape : ''; + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/tf-graph-scene.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/tf-graph-scene.ts new file mode 100644 index 0000000000000000000000000000000000000000..42f3f9884b741ab0971c999986f6df73c142de34 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/tf-graph-scene.ts @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as d3 from 'd3'; +import { Annotation, RenderNodeInfo } from './render'; + +type Selection = d3.Selection; +// This technically extends Polymer.Component whose constructor is not +// accessible. +export abstract class TfGraphScene extends HTMLElement { + maxMetanodeLabelLength: number; + maxMetanodeLabelLengthLargeFont: number; + maxMetanodeLabelLengthFontSize: number; + templateIndex: (name: string) => number | null; + abstract fire(eventName: string, daat: any): void; + abstract addNodeGroup(name: string, selection: Selection): void; + abstract removeNodeGroup(name: string): void; + abstract removeAnnotationGroup(annotation: Annotation, renderNode: RenderNodeInfo): void; + abstract isNodeExpanded(node: RenderNodeInfo): boolean; + abstract isNodeHighlighted(nodeName: string): boolean; + abstract isNodeSelected(nodeName: string): boolean; + abstract isNodeLinked(nodeName: string): boolean; + abstract getAnnotationGroupsIndex(name: string): Selection; + abstract getGraphSvgRoot(): SVGElement; + abstract getContextMenu(): HTMLElement; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/tf-node-icon.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/tf-node-icon.ts new file mode 100644 index 0000000000000000000000000000000000000000..9225ee7002670061f3131c64102c3bcba5b515f3 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/tf-node-icon.ts @@ -0,0 +1,176 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { customElement, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import * as tf_graph from '../tf_graph_common/graph'; +import * as tf_graph_scene_node from '../tf_graph_common/node'; +import './tf-graph-icon'; +import * as tf_graph_icon from './tf-graph-icon'; + +@customElement('tf-node-icon') +class TfNodeIcon extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + + `; + + /** + * Node to represent with an icon. Optional, but if specified, its + * properties override those defined in the type, vertical, const and + * summary properties. + * This property is a tf.graph.Node. + */ + @property({ + type: Object, + }) + node: object | null = null; + + /** + * Render node information associated with this node. Optional. If + * specified, this is only used when computing the fill of the icon + * element. + * This property is a tf.graph.render.RenderNodeInfo. + */ + @property({ + type: Object, + }) + renderInfo: object | null = null; + + /** Type of node to draw (ignored if node is set). */ + @property({ + type: String, + }) + type: string | null = null; + + /** Direction for series (ignored for other types). */ + @property({ + type: Boolean, + }) + vertical: boolean = false; + + /** Whether the op is Const (ignored for non-ops). */ + @property({ + type: Boolean, + }) + const: boolean = false; + + /** Whether the op is a Summary (ignored for non-ops). */ + @property({ + type: Boolean, + }) + summary: boolean = false; + + /** + * Fill for the icon, optional. If fill is specified and node is not + * specified, then this value will override the default for the + * element. However, if node is specified, this value will be ignored. + */ + @property({ + type: String, + }) + fill: string | null = null; + + /** Height of the SVG element in pixels, used for scaling. */ + @property({ + type: Number, + }) + height: number = 20; + + @property({ + type: String, + computed: '_computeFillOverride(node, renderInfo, fill)', + observer: '_onFillOverrideChanged', + }) + _fillOverride: string; + + /** + * Returns fill value based on node and configuration. If any of those + * configurations or node value missing, it returns `fill` property. + * Note that if this evaluates to null, it will be chosen based on + * the type of the node. + */ + _computeFillOverride(inputNode, inputRenderInfo, inputFill) { + if (inputNode && inputRenderInfo) { + return tf_graph_scene_node.getFillForNode(inputRenderInfo); + } + return inputFill; + } + _getStrokeOverride(fillOverride) { + return fillOverride ? tf_graph_scene_node.getStrokeForFill(fillOverride) : null; + } + /** + * Returns graph-icon type from input, type, and summary. + */ + _getType(inputNode, isSummary, isConst, inputType) { + const { GraphIconType } = tf_graph_icon; + if (inputNode) { + switch (inputNode.type) { + case tf_graph.NodeType.OP: { + const opName = inputNode.op; + // TODO(tensorboarad-team): `op` should have a predictable type. + // Remove the type check. + if (typeof opName !== 'string') return GraphIconType.OP; + if (opName === 'Const' || isConst) return GraphIconType.CONST; + if (opName.endsWith('Summary') || isSummary) { + return GraphIconType.SUMMARY; + } + return GraphIconType.OP; + } + case tf_graph.NodeType.META: + return GraphIconType.META; + case tf_graph.NodeType.SERIES: + return GraphIconType.SERIES; + } + } + return inputType; + } + /** + * Test whether the specified node should be represented as a vertical + * series. Defaults to the value of the vertical property if node is + * not specified. + */ + _isVertical(inputNode, inputVertical) { + if (inputNode) { + return inputNode.hasNonControlEdges; + } + return !!inputVertical; + } + _getFaded(itemRenderInfo) { + return itemRenderInfo && itemRenderInfo.isFadedOut; + } + _onFillOverrideChanged(newFill, oldFill) { + const { node, renderInfo } = this; + if (newFill !== oldFill) { + tf_graph_scene_node.removeGradientDefinitions((this.$.icon as any).getSvgDefinableElement()); + } + if (node && renderInfo) { + tf_graph_scene_node.getFillForNode(renderInfo as any); + } + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/util.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/util.ts new file mode 100644 index 0000000000000000000000000000000000000000..965415cefaf624b764117ffccab53fcc6af3f07d --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_common/util.ts @@ -0,0 +1,444 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the 'License'); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an 'AS IS' BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/** + * @fileoverview Utility functions for the tensorflow graph visualizer. + */ +import * as _ from 'lodash'; +import { notifyActionEventFromPolymer } from '../tb_debug'; +import { + GraphDebugActionEventId, + GraphDebugTimingEventId, + GRAPH_DEBUG_ACTION_EVENT_CATEGORY, + GRAPH_DEBUG_TIMING_EVENT_CATEGORY, +} from '../tb_debug/types'; +import { NodeStats, ProgressTracker } from './common'; + +const ASYNC_TASK_DELAY = 20; + +interface DebugTimingEvent { + timingId: GraphDebugTimingEventId; + // An associated duration in milliseconds for a timing event. + eventValue: number; +} + +interface DebugActionEvent { + actionId: GraphDebugActionEventId; + eventLabel?: string; +} + +export type DebugEvent = DebugTimingEvent | DebugActionEvent; + +function isDebugTimingEvent(debugEvent: DebugEvent): debugEvent is DebugTimingEvent { + return debugEvent.hasOwnProperty('timingId'); +} + +export function notifyDebugEvent(debugEvent: DebugEvent) { + if (isDebugTimingEvent(debugEvent)) { + notifyActionEventFromPolymer({ + eventCategory: GRAPH_DEBUG_TIMING_EVENT_CATEGORY, + eventAction: debugEvent.timingId, + eventValue: debugEvent.eventValue, + }); + } else { + notifyActionEventFromPolymer({ + eventCategory: GRAPH_DEBUG_ACTION_EVENT_CATEGORY, + eventAction: debugEvent.actionId, + eventLabel: debugEvent.eventLabel, + }); + } +} + +/** + * Measure and log a synchronous task. + */ +export function time(msg: string, task: () => T, debugEventId?: GraphDebugTimingEventId) { + let start = Date.now(); + let result = task(); + const durationInMs = Date.now() - start; + /* tslint:disable */ + console.log(msg, ':', durationInMs, 'ms'); + /* tslint:enable */ + if (debugEventId) { + notifyDebugEvent({ timingId: debugEventId, eventValue: durationInMs }); + } + return result; +} +/** + * Creates a tracker that sets the progress property of the + * provided polymer component. The provided component must have + * a property called 'progress' that is not read-only. The progress + * property is an object with a numerical 'value' property and a + * string 'msg' property. + */ +export function getTracker(polymerComponent: any): ProgressTracker { + return { + setMessage: function (msg) { + polymerComponent.set('progress', { + value: polymerComponent.progress.value, + msg: msg, + }); + }, + updateProgress: function (value) { + polymerComponent.set('progress', { + value: polymerComponent.progress.value + value, + msg: polymerComponent.progress.msg, + }); + }, + reportError: function (msg: string, err) { + // Log the stack trace in the console. + console.error(err.stack); + // And send a user-friendly message to the UI. + polymerComponent.set('progress', { + value: polymerComponent.progress.value, + msg: msg, + error: true, + }); + }, + }; +} +/** + * Creates a tracker for a subtask given the parent tracker, the total + * progress + * of the subtask and the subtask message. The parent task should pass a + * subtracker to its subtasks. The subtask reports its own progress which + * becomes relative to the main task. + */ +export function getSubtaskTracker( + parentTracker: ProgressTracker, + impactOnTotalProgress: number, + subtaskMsg: string, +): ProgressTracker { + return { + setMessage: function (progressMsg) { + // The parent should show a concatenation of its message along with + // its subtask tracker message. + parentTracker.setMessage(subtaskMsg + ': ' + progressMsg); + }, + updateProgress: function (incrementValue) { + // Update the parent progress relative to the child progress. + // For example, if the sub-task progresses by 30%, and the impact on the + // total progress is 50%, then the task progresses by 30% * 50% = 15%. + parentTracker.updateProgress((incrementValue * impactOnTotalProgress) / 100); + }, + reportError: function (msg: string, err: Error) { + // The parent should show a concatenation of its message along with + // its subtask error message. + parentTracker.reportError(subtaskMsg + ': ' + msg, err); + }, + }; +} +/** + * Runs a synchronous expensive task and return the result. + * Please use runAsyncPromiseTask in case a task returns a Promise. + */ +export function runTask( + msg: string, + incProgressValue: number, + task: () => T, + tracker: ProgressTracker, + debugEventId?: GraphDebugTimingEventId, +): T { + // Update the progress message to say the current running task. + tracker.setMessage(msg); + // Run the expensive task with a delay that gives enough time for the + // UI to update. + try { + let result = time(msg, task, debugEventId); + // Update the progress value. + tracker.updateProgress(incProgressValue); + // Return the result to be used by other tasks. + return result; + } catch (e: any) { + // Errors that happen inside asynchronous tasks are + // reported to the tracker using a user-friendly message. + tracker.reportError('Failed ' + msg, e); + return null!; + } +} +/** + * Runs an expensive task asynchronously and returns a promise of the result. + */ +export function runAsyncTask( + msg: string, + incProgressValue: number, + task: () => T, + tracker: ProgressTracker | undefined, + debugEventId?: GraphDebugTimingEventId, +): Promise { + return new Promise((resolve, reject) => { + // Update the progress message to say the current running task. + tracker && tracker.setMessage(msg); + // Run the expensive task with a delay that gives enough time for the + // UI to update. + setTimeout(function () { + try { + let result = time(msg, task, debugEventId); + // Update the progress value. + tracker && tracker.updateProgress(incProgressValue); + // Return the result to be used by other tasks. + resolve(result); + } catch (e: any) { + // Errors that happen inside asynchronous tasks are + // reported to the tracker using a user-friendly message. + tracker && tracker.reportError('Failed ' + msg, e); + } + }, ASYNC_TASK_DELAY); + }); +} +/** + * Asynchronously runs an expensive task that returns a promise. Updates the + * tracker's progress after the promise resolves. Returns a new promise that + * resolves after the progress is updated. + */ +export function runAsyncPromiseTask( + msg: string, + incProgressValue: number, + task: () => Promise, + tracker: ProgressTracker, + debugEventId?: GraphDebugTimingEventId, +): Promise { + return new Promise((resolve, reject) => { + let handleError = function (e) { + // Errors that happen inside asynchronous tasks are + // reported to the tracker using a user-friendly message. + tracker.reportError('Failed ' + msg, e); + reject(e); + }; + // Update the progress message to say the current running task. + tracker.setMessage(msg); + // Run the expensive task with a delay that gives enough time for the + // UI to update. + setTimeout(function () { + try { + let start = Date.now(); + task() + .then(function (value) { + const durationInMs = Date.now() - start; + /* tslint:disable */ + console.log(msg, ':', durationInMs, 'ms'); + // Update the progress value. + tracker.updateProgress(incProgressValue); + notifyDebugEvent({ + timingId: debugEventId!, + eventValue: durationInMs, + }); + // Return the result to be used by other tasks. + resolve(value); + }) + .catch(handleError); + } catch (e) { + handleError(e); + } + }, ASYNC_TASK_DELAY); + }); +} +/** + * Returns a query selector with escaped special characters that are not + * allowed in a query selector. + */ +export function escapeQuerySelector(querySelector: string): string { + return querySelector.replace(/([:.\[\],/\\\(\)])/g, '\\$1'); +} +interface Unit { + symbol: string; + numUnits?: number; +} +type Units = ReadonlyArray; +// For unit conversion. +export const MEMORY_UNITS: Units = [ + // Atomic unit. + { symbol: 'B' }, + // numUnits specifies how many previous units this unit contains. + { symbol: 'KB', numUnits: 1024 }, + { symbol: 'MB', numUnits: 1024 }, + { symbol: 'GB', numUnits: 1024 }, + { symbol: 'TB', numUnits: 1024 }, + { symbol: 'PB', numUnits: 1024 }, +]; +export const TIME_UNITS: Units = [ + // Atomic unit. Finest granularity in TensorFlow stat collection. + { symbol: 'µs' }, + // numUnits specifies how many previous units this unit contains. + { symbol: 'ms', numUnits: 1000 }, + { symbol: 's', numUnits: 1000 }, + { symbol: 'min', numUnits: 60 }, + { symbol: 'hr', numUnits: 60 }, + { symbol: 'days', numUnits: 24 }, +]; +/** + * Returns the human readable version of the unit. + * (e.g. 1.35 GB, 23 MB, 34 ms, 6.53 min etc). + */ +export function convertUnitsToHumanReadable(value: number, units: Units, unitIndex: number = 0) { + if (unitIndex + 1 < units.length && value >= units[unitIndex + 1].numUnits!) { + return convertUnitsToHumanReadable(value / units[unitIndex + 1].numUnits!, units, unitIndex + 1); + } + // toPrecision() has the tendency to return a number in scientific + // notation and casting back to a number brings it back to normal notation. + // e.g., + // > value = 213; value.toPrecision(1) + // < "2e+2" + // > Number(value.toPrecision(1)) + // < 200 + return Number(value.toPrecision(3)) + ' ' + units[unitIndex].symbol; +} +export function hasDisplayableNodeStats(stats: NodeStats) { + if (stats && (stats.totalBytes > 0 || stats.getTotalMicros() > 0 || stats.outputSize)) { + return true; + } + return false; +} +/** + * Given a list of strings, it returns a new list of strings with the longest + * common prefix removed. If the common prefix is one of the strings in the + * list, it returns the original strings. + */ +export function removeCommonPrefix(strings: string[]) { + if (strings.length < 2) { + return strings; + } + let index = 0; + let largestIndex = 0; + // Find the shortest name across all strings. + let minLength = _.min(_.map(strings, (str) => str.length))!; + while (true) { + index++; + let prefixes = _.map(strings, (str) => str.substring(0, index)); + let allTheSame = prefixes.every((prefix, i) => { + return i === 0 ? true : prefix === prefixes[i - 1]; + }); + if (allTheSame) { + if (index >= minLength) { + // There is a string whose whole name is a prefix to other string. + // In this case, we return the original list of string. + return strings; + } + largestIndex = index; + } else { + break; + } + } + return _.map(strings, (str) => str.substring(largestIndex)); +} +/** + * Given a timestamp in microseconds, return a human-friendly string denoting + * how long ago the timestamp was. + */ +export function computeHumanFriendlyTime(timeInMicroseconds: number) { + var timeDifferenceInMs = +new Date() - +new Date(timeInMicroseconds / 1000); + if (timeDifferenceInMs < 30000) { + return 'just now'; + } else if (timeDifferenceInMs < 60000) { + return Math.floor(timeDifferenceInMs / 1000) + ' seconds ago'; + } else if (timeDifferenceInMs < 120000) { + return 'a minute ago'; + } else if (timeDifferenceInMs < 3600000) { + return Math.floor(timeDifferenceInMs / 60000) + ' minutes ago'; + } else if (Math.floor(timeDifferenceInMs / 3600000) == 1) { + return 'an hour ago'; + } else if (timeDifferenceInMs < 86400000) { + return Math.floor(timeDifferenceInMs / 3600000) + ' hours ago'; + } else if (timeDifferenceInMs < 172800000) { + return 'yesterday'; + } + return Math.floor(timeDifferenceInMs / 86400000) + ' days ago'; +} + +const canvas = document.createElement('canvas'); +const measurerContext = canvas.getContext('2d'); + +/** + * Returns width of `text` rendered with Roboto at provided fontSize. + */ +export function measureTextWidth(text: string, fontSize: number): number { + if (measurerContext) measurerContext.font = `${fontSize}px Roboto, sans-serif`; + return measurerContext?.measureText(text).width!; +} + +/** + * Returns, if rendered `text` does not fit into maxWidth, truncated string with trailing + * ellipsis. + */ +export function maybeTruncateString(text: string, fontSize: number, maxWidth: number): string { + if (!text) return ''; + if (measureTextWidth(text, fontSize) <= maxWidth) return text; + + let start = 0; + let end = text.length; + while (start < end) { + const middle = start + Math.round((end - start) / 2); + const substring = text.slice(0, middle) + '…'; + if (measureTextWidth(substring, fontSize) <= maxWidth) { + start = middle; + } else { + end = middle - 1; + } + } + + return start === 0 ? text[0] : text.slice(0, start) + '…'; +} + +/** + * Extend this subclass to receive event dispatching traits. + * Useful for when various locations need to observe changes on + * a common instance, who has a limited lifetime. + * + * This is not intended for use with framework-supported elements. + * For example, prefer using `@Output myEmitter` on Angular + * Components, or Polymer's `on-myprop-changed` for Polymer + * elements, instead. + * + * Example usage: + * + * ``` + * export enum ReactorEvent {EXPLODED} + * export class Reactor extends Dispatcher { + * _update() { + * this.dispatchEvent(ReactorEvent.EXPLODED); + * } + * } + * + * // Elsewhere + * const r = new Reactor(); + * r.addEventListener(ReactorEvent.EXPLODED, this._cleanup); + * ``` + */ +export class Dispatcher { + private eventTypeToListeners = new Map(); + + private getListeners(eventType) { + if (!this.eventTypeToListeners.has(eventType)) { + this.eventTypeToListeners.set(eventType, []); + } + return this.eventTypeToListeners.get(eventType); + } + + addListener(eventType: EventType, listener: Function) { + this.getListeners(eventType)?.push(listener); + } + + removeListener(eventType: EventType, listener: Function) { + const newListeners = this.getListeners(eventType)?.filter((x) => { + return x !== listener; + }); + this.eventTypeToListeners.set(eventType, newListeners!); + } + + dispatchEvent(eventType: EventType, payload?: any) { + for (const listener of this.getListeners(eventType)!) { + listener(payload); + } + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_controls/tf-graph-controls.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_controls/tf-graph-controls.ts new file mode 100644 index 0000000000000000000000000000000000000000..1e4683440ad53b2b9137a691b75b9d4676e8aa60 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_controls/tf-graph-controls.ts @@ -0,0 +1,1845 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { customElement, property, observe } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import * as _ from 'lodash'; +import { DarkModeMixin } from '../polymer/dark_mode_mixin'; +import '../polymer/irons_and_papers'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import * as tb_debug from '../tb_debug'; +import '../tf_dashboard_common/tensorboard-color'; +import * as tf_graph_common from '../tf_graph_common/common'; +import * as tf_graph_proto from '../tf_graph_common/proto'; +import * as tf_graph_render from '../tf_graph_common/render'; +import '../tf_graph_common/tf-graph-icon'; +import * as tf_graph_util from '../tf_graph_common/util'; +import '../tf_graph_node_search/tf-graph-node-search'; +import { getRouter } from '../tf_backend/router'; +import { fetchPbTxt, parseGraphPbTxt } from '../tf_graph_common/parser'; +import '../tf_graph_loader/tf-graph-dashboard-loader'; +import * as tf_graph_parser from '../tf_graph_common/parser'; +import * as tf_graph_node from '../tf_graph_common/node'; +import * as tf_node_info from '../tf_graph_info/tf-node-info'; + +export interface Selection { + run: string; + tag: string | null; + type: tf_graph_common.SelectionType; + batch: number; + step: number; +} +// TODO(stephanwlee): Move this to tf-graph-dashboard +export interface TagItem { + tag: string | null; + displayName: string; + conceptualGraph: boolean; + opGraph: boolean; + profile: boolean; +} +// TODO(stephanwlee): Move this to tf-graph-dashboard +export interface RunItem { + name: string; + tags: TagItem[]; +} +// TODO(stephanwlee): Move this to tf-graph-dashboard +export type Dataset = Array; +@customElement('tf-graph-controls') +class TfGraphControls extends LegacyElementMixin(DarkModeMixin(PolymerElement)) { + static readonly template = html` + + +
+ + + + +
+
+
+
+ + + Fit to screen + +
+
+ + + Download PNG + +
+ +
+
Run ([[datasets.length]])
+ + + + + +
+ +
+ + +
+
+
+ + + +
+ + + +
+
+ + Module or Operators +
+
+ + +
+ + + + + + + + + + + + + + +
+ + + Unexpanded Module or Operators +
+ ? + +
+ Unexpandable Node: It can be an Api, operator or module. It cannot be expanded because it has no + subnodes +
+
+
+
+ + + Api List +
+ ? + +
Apis between modules
+
+
+
+ +
+
+
+
+
+ + + +
+ + +
+
+
+
+ `; + @property({ + type: Array, + observer: '_datasetsChanged', + }) + datasets: any = []; + @property({ + type: Object, + }) + colorset; + @property({ + type: Object, + }) + matchednodeset: any = []; + @property({ + type: Object, + }) + unMatchednodeset: any = []; + @property({ + type: Object, + }) + colorSetChanged; + @property({ + type: Array, + }) + selectColor: any = []; + /** + * @type {tf_graph_render.RenderGraphInfo} + */ + @property({ + type: Object, + }) + renderHierarchy: tf_graph_render.RenderGraphInfo; + /** + * @type {!Selection} + */ + @property({ + type: Object, + notify: true, + readOnly: true, + computed: + '_computeSelection(datasets, _selectedRunIndex, _selectedTagIndex, _selectedGraphType, _selectedBatch, _selectedStep)', + }) + selection: object; + @property({ + type: Object, + notify: true, + }) + selectedFile: object; + @property({ + type: Object, + }) + _selectedColors; + @property({ + type: Number, + observer: '_selectedRunIndexChanged', + }) + _selectedRunIndex: number = 0; + @property({ + type: Number, + }) + _selectedBatch: number = -1; + @property({ + type: Number, + }) + _selectedStep: number = -1; + @property({ + type: Boolean, + notify: true, + }) + traceInputs: boolean = false; + @property({ + type: Number, + observer: '_selectedTagIndexChanged', + }) + _selectedTagIndex: number = 0; + /** + * @type {tf_graph_common.SelectionType} + */ + @property({ + type: String, + }) + _selectedGraphType: string = tf_graph_common.SelectionType.OP_GRAPH; + @property({ + type: String, + notify: true, + }) + selectedNode: string; + selectedNPUNode: string = ''; + selectedBenchNode: string = ''; + selectedMatchedNPUNode: string = ''; + selectedMatchedBenchNode: string = ''; + @property({ + type: Boolean, + }) + showSessionRunsDropdown: boolean = true; + @property({ + type: Boolean, + }) + showUploadButton: boolean = false; + @property({ + type: Boolean, + }) + _expanded: boolean = true; + @property({ + type: Boolean, + }) + _legendOpened: boolean = true; + _Colors: boolean = true; + @property({ type: Object }) + _downloadFilename = 'graph.png'; + menuData: any; + graphDef: any; + @property({ type: Object, observer: '_newMenu' }) + microsteps: any; + steplist: any; + @property({ type: Object, observer: '_updateColorItems' }) + colors: any; + @property({ type: Object }) + nodeName: any; + @property({ type: Object }) + menu: any; + @property({ type: Object }) + unmatched: any = []; + NPU_unmatched: any = []; + Bench_unmatched: any = []; + NPU_matched: any = []; + Bench_matched: any = []; + matched: any = []; + matchedlist: any = []; + @property({ type: Object }) + precisionmenu: any = []; + @property({ type: Number }) + selectedSide: number = 0; + _selectedNpuMatchMenu: number = -1; + _selectedBenchMatchMenu: number = -1; + _selectedUnMatchMenu: number = -1; + + getMenuNodeName() { + return this.nodeName; + } + _showTabContent(buttonText, contentId) { + // Remove 'active' class from all buttons + this.shadowRoot?.querySelectorAll('.tab-button').forEach((button) => { + button.classList.remove('active'); + }); + + // Add 'active' class to the clicked button + const buttons = this.shadowRoot?.querySelectorAll('.tab-button'); + buttons?.forEach((button) => { + if ((button as HTMLElement).innerHTML === buttonText) { + button?.classList.add('active'); + } + }); + + // Hide all content + this.shadowRoot?.querySelectorAll('.tab-content').forEach((content) => { + content.classList.add('hidden'); + }); + + // Show the selected content + const selectedContent = this.shadowRoot?.getElementById(contentId); + if (selectedContent) { + selectedContent.classList.remove('hidden'); + } + } + + // 使用示例 + _showNodeControls() { + this._showTabContent('设置', 'nodes-content'); + } + + _showDirectoryStructure() { + this._showTabContent('目录', 'directory-content'); + } + + _showSearchStructure() { + this._showTabContent('搜索', 'search-content'); + } + + _showMatch() { + this._showTabContent('匹配', 'match-content'); + } + + _onGraphTypeChangedByUserGesture() { + tf_graph_util.notifyDebugEvent({ + actionId: tb_debug.GraphDebugEventId.GRAPH_TYPE_CHANGED, + eventLabel: this._selectedGraphType, + }); + } + + _numTags(datasets: Dataset, _selectedRunIndex: number) { + return this._getTags(datasets, _selectedRunIndex).length; + } + _getTags(datasets: Dataset, _selectedRunIndex: number) { + if (!datasets || !datasets[_selectedRunIndex]) { + return []; + } + return datasets[_selectedRunIndex].tags; + } + _fit() { + this.fire('fit-tap'); + } + @observe('colorset') + _observe() { + if (this.colorset.length !== 0) { + const colorsets = this.colorset; + for (const item of colorsets) { + if (item[1].value.length === 0) { + item[1].value.push('无匹配节点'); + } + } + this.colorSetChanged = colorsets; + } else { + return; + } + } + @observe('unmatched') + _observeUnmatchedNode() { + this.set('NPU_unmatched', this.unmatched[0]); + this.set('Bench_unmatched', this.unmatched[1]); + } + @observe('matchedlist', 'selection') + _observeMatchedList() { + this.set('NPU_matched', []); + this.set('Bench_matched', []); + this.set('matched', []); + if (this.matchedlist) { + for (const item of this.matchedlist) { + this.NPU_matched = [...this.NPU_matched, item[0]]; + this.Bench_matched = [...this.Bench_matched, item[1]]; + this.matched = [...this.matched, [item[0], item[1]]]; + } + this.set('NPU_matched', this.NPU_matched); + this.set('Bench_matched', this.Bench_matched); + this.set('matched', this.matched); + } + } + _observePrecsionNode(event) { + let prefix = ''; + const hasBNode = this.renderHierarchy.renderedOpNames.some((name: string) => name.startsWith('B___')); + if (hasBNode) { + prefix = 'N___'; + } + const node = prefix + event.model.item; + this.set('selectedNode', node); + } + _observeNPUUnMatchedNode(event) { + let prefix = ''; + const hasBNode = this.renderHierarchy.renderedOpNames.some((name: string) => name.startsWith('B___')); + if (hasBNode) { + prefix = 'N___'; + } + const node = prefix + event.model.item; + this.set('selectedNPUNode', node); + this.set('selectedNode', node); + } + _observeBenchUnMatchedNode(event) { + let prefix = ''; + const hasBNode = this.renderHierarchy.renderedOpNames.some((name: string) => name.startsWith('B___')); + if (hasBNode) { + prefix = 'B___'; + } + const node = prefix + event.model.item; + this.set('selectedBenchNode', node); + this.set('selectedNode', node); + } + _observeNPUMatchedNode(event) { + let prefix = ''; + const hasBNode = this.renderHierarchy.renderedOpNames.some((name: string) => name.startsWith('B___')); + if (hasBNode) { + prefix = 'N___'; + } + const node = prefix + event.model.item; + this.set('selectedMatchedNPUNode', node); + const matched_node = this.findNodeInMatched(node.slice(4), 0); + if (matched_node) { + this.set('selectedMatchedBenchNode', `B___${matched_node}`); + } + this.set('selectedNode', node); + } + _observeBenchMatchedNode(event) { + let prefix = ''; + const hasBNode = this.renderHierarchy.renderedOpNames.some((name: string) => name.startsWith('B___')); + if (hasBNode) { + prefix = 'B___'; + } + const node = prefix + event.model.item; + this.set('selectedMatchedBenchNode', node); + const matched_node = this.findNodeInMatched(node.slice(4), 1); + if (matched_node) { + this.set('selectedMatchedNPUNode', `N___${matched_node}`); + } + this.set('selectedNode', node); + } + findNodeInMatched(node, side) { + // 遍历 matched 数组 + for (let i = 0; i < this.matched.length; i++) { + // 获取当前子数组 + const pair = this.matched[i]; + // 确保子数组的长度为 2,防止越界 + if (pair.length >= 2 && pair[side] === node) { + if (side == 0) { + return pair[1]; // 返回第二项 + } else { + return pair[0]; // 返回第一项 + } + } + } + // 如果没找到 node + return null; // 返回 null 或其他指示未找到的值 + } + _searchNpu() { + this.selectedSide = 0; + } + _searchBench() { + this.selectedSide = 1; + } + showDynamicDialog(message) { + // 检查是否已经有显示的对话框,避免重复添加 + let existingDialog = this.shadowRoot?.querySelector('#dynamicDialog'); + if (existingDialog) { + existingDialog.remove(); // 删除旧的对话框 + } + // 创建新的对话框 + const dialog = document.createElement('paper-dialog'); + dialog.id = 'dynamicDialog'; + // 添加标题 + const title = document.createElement('h2'); + title.textContent = '提示'; + dialog.appendChild(title); + // 添加提示内容 + const content = document.createElement('div'); + content.textContent = message; + dialog.appendChild(content); + // 添加按钮 + const buttonContainer = document.createElement('div'); + buttonContainer.classList.add('buttons'); + const closeButton = document.createElement('paper-button'); + closeButton.setAttribute('dialog-dismiss', ''); + closeButton.textContent = '关闭'; + buttonContainer.appendChild(closeButton); + dialog.appendChild(buttonContainer); + // 添加到 shadow DOM + this.shadowRoot?.appendChild(dialog); + // 打开对话框 + dialog.open(); + } + + _handleUnmatchSearch(event) { + const action = event.target.getAttribute('data-action'); + const menuFirstRow = this.menu[1]; + const selectedNode = this.selectedNode; + if (this.Bench_unmatched.length === 0) { + this.showDynamicDialog('标杆侧没有未匹配节点'); + return; + } + /* 如果用户没选中节点,默认选中 precisionmenu 的第一个节点, + 获取 selectedNode 在 menuFirstRow 中的索引,不存在则默认选中 precisionmenu 的第一个节点*/ + const startIndex = selectedNode ? menuFirstRow.indexOf(selectedNode.slice(4)) : -1; + if (!selectedNode || startIndex === -1) { + this.set('selectedNode', `B___${this.Bench_unmatched[0]}`); + return; + } + const findNextNode = () => { + if (this.Bench_unmatched.includes(selectedNode)) { + const currentIndex = this.Bench_unmatched.indexOf(selectedNode); + if (currentIndex + 1 >= this.Bench_unmatched.length) { + this.showDynamicDialog('没有下一个未匹配节点'); + return null; + } + return this.Bench_unmatched[currentIndex + 1]; + } + for (let i = startIndex + 1; i < menuFirstRow.length; i++) { + if (this.Bench_unmatched.includes(menuFirstRow[i])) { + return menuFirstRow[i]; + } + } + this.showDynamicDialog('没有下一个未匹配节点'); + return null; + }; + const findPreviousNode = () => { + if (this.Bench_unmatched.includes(selectedNode)) { + const currentIndex = this.Bench_unmatched.indexOf(selectedNode); + if (currentIndex === 0) { + this.showDynamicDialog('没有上一个未匹配节点'); + return null; + } + return this.Bench_unmatched[currentIndex - 1]; + } + + for (let i = startIndex - 1; i >= 0; i--) { + if (this.Bench_unmatched.includes(menuFirstRow[i])) { + return menuFirstRow[i]; + } + } + this.showDynamicDialog('没有上一个未匹配节点'); + return null; + }; + // 执行查找操作 + const nextNode = action === 'next' ? findNextNode() : findPreviousNode(); + if (nextNode) { + this.set('selectedNode', `B___${nextNode}`); + } + } + + _handlePrecisonSearch(event) { + const action = event.target.getAttribute('data-action'); + const menuFirstRow = this.menu[0]; + const selectedNode = this.selectedNode; + const hasBNode = this.renderHierarchy.renderedOpNames.some((name: string) => name.startsWith('B___')); + const showDialog = (message: string) => { + this.showDynamicDialog(message); + }; + // 设置 selectedNode 方法封装 + const setDefaultNode = () => { + const defaultNode = hasBNode ? `N___${this.precisionmenu[0]}` : this.precisionmenu[0]; + this.set('selectedNode', defaultNode); + }; + // 校验用户是否选择了颜色 + if (this.selectColor.length === 0) { + showDialog('请选择颜色'); + return; + } + // 校验 precisionmenu 是否为空 + if (this.precisionmenu.length === 0) { + showDialog('选择的颜色没有节点存在'); + return; + } + // 如果用户未选中节点,设置默认节点 + if (!selectedNode) { + setDefaultNode(); + return; + } + // 获取 selectedNode 在 menuFirstRow 中的索引 + const slicedNode = hasBNode ? selectedNode.slice(4) : selectedNode; + const startIndex = menuFirstRow.indexOf(slicedNode); + // 如果索引无效,设置默认节点 + if (startIndex === -1) { + setDefaultNode(); + return; + } + // 根据操作类型查找下一个或上一个节点 + const findNextNode = () => { + if (this.precisionmenu.includes(selectedNode)) { + const currentIndex = this.precisionmenu.indexOf(selectedNode); + if (currentIndex + 1 >= this.precisionmenu.length) { + this.showDynamicDialog('没有下一个问题节点'); + return null; + } + return this.precisionmenu[currentIndex + 1]; + } + for (let i = startIndex + 1; i < menuFirstRow.length; i++) { + if (this.precisionmenu.includes(menuFirstRow[i])) { + return menuFirstRow[i]; + } + } + this.showDynamicDialog('没有下一个问题节点'); + return null; + }; + const findPreviousNode = () => { + if (this.precisionmenu.includes(selectedNode)) { + const currentIndex = this.precisionmenu.indexOf(selectedNode); + if (currentIndex === 0) { + this.showDynamicDialog('没有上一个问题节点'); + return null; + } + return this.precisionmenu[currentIndex - 1]; + } + + for (let i = startIndex - 1; i >= 0; i--) { + if (this.precisionmenu.includes(menuFirstRow[i])) { + return menuFirstRow[i]; + } + } + this.showDynamicDialog('没有上一个问题节点'); + return null; + }; + // 执行查找操作 + const nextNode = action === 'next' ? findNextNode() : findPreviousNode(); + if (nextNode) { + if (hasBNode) { + this.set('selectedNode', `N___${nextNode}`); + } else { + this.set('selectedNode', nextNode); + } + } + } + + async _handleMatchedNodesClick(this) { + // 打开弹窗 + if (this.selectedNPUNode == '' || this.selectedBenchNode == '') { + this.showDynamicDialog('节点不可匹配'); + return; + } + const params = new URLSearchParams(); + const run = this.datasets[this._selectedRunIndex].name; + const tag = this.datasets[this._selectedRunIndex].tags[this._selectedTagIndex].tag; + params.set('run', run); + if (tag) params.set('tag', tag); + params.set('NPU', this.selectedNPUNode); + params.set('Bench', this.selectedBenchNode); + const precisionPath = getRouter().pluginRouteForSrc('graph_ascend', '/match', params); + const precisionStr = await tf_graph_parser.fetchPbTxt(precisionPath); // 获取异步的 ArrayBuffer + const decoder = new TextDecoder(); + const decodedStr = decoder.decode(precisionStr); // 解码 ArrayBuffer 到字符串 + const dialogMessageMap: { [key: string]: string } = { + inputshape: '匹配失败,inputshape不匹配', + outputshape: '匹配失败,outputshape不匹配', + inputNone: '匹配失败,input中包含None值', + outputNone: '匹配失败,output中包含None值', + }; + + const message = dialogMessageMap[decodedStr]; + if (message) { + this.showDynamicDialog(message); // 使用动态生成的对话框 + } + let resultArray: any[] = []; + resultArray = JSON.parse(decodedStr) as any[]; + if (resultArray.length !== 0) { + this.push('matchednodeset', resultArray); + if (this.unMatchednodeset.length !== 0) { + const has = this.unMatchednodeset.indexOf(this.selectedNPUNode); + if (has !== -1) { + this.unMatchednodeset = [...this.unMatchednodeset.slice(0, has), ...this.unMatchednodeset.slice(has + 1)]; + } + } + tf_graph_node.getMatched(this.matchednodeset); + tf_node_info.getMatched(this.matchednodeset); + const index_N = this.NPU_unmatched.indexOf(this.selectedNPUNode.slice(4)); + if (index_N !== -1) { + this.splice('NPU_unmatched', index_N, 1); + this.notifyPath('NPU_unmatched'); + } + const index_B = this.Bench_unmatched.indexOf(this.selectedBenchNode.slice(4)); + if (index_B !== -1) { + this.splice('Bench_unmatched', index_B, 1); + this.notifyPath('Bench_unmatched'); + } + this.set('_selectedNpuMatchMenu', -1); + this.set('_selectedBenchMatchMenu', -1); + this.NPU_matched = [...this.NPU_matched, this.selectedNPUNode.slice(4)]; + this.Bench_matched = [...this.Bench_matched, this.selectedBenchNode.slice(4)]; + this.matched = [...this.matched, [this.selectedNPUNode.slice(4), this.selectedBenchNode.slice(4)]]; + this.showDynamicDialog('节点匹配成功'); + this.set('selectedNode', ''); + this.set('selectedNode', this.selectedNPUNode); + this.selectedNPUNode = ''; + this.selectedBenchNode = ''; + } else { + this.showDynamicDialog('节点不可匹配'); + } + } + + async _handleUnMatchedNodesClick(this) { + // 打开弹窗 + if (this.selectedMatchedNPUNode == '' || this.selectedMatchedBenchNode == '') { + this.showDynamicDialog('取消匹配失败,请核对选择节点'); + return; + } + const existsInMatch = this.matched.some( + ([NPU_matched, Bench_matched]) => + NPU_matched === this.selectedMatchedNPUNode.slice(4) && + Bench_matched === this.selectedMatchedBenchNode.slice(4), + ); + if (!existsInMatch) { + this.showDynamicDialog('取消匹配失败,请核对选择节点'); + return; + } + this.NPU_unmatched.push(this.selectedMatchedNPUNode.slice(4)); + this.NPU_unmatched = [...this.NPU_unmatched]; + this.notifyPath('NPU_unmatched'); + this.Bench_unmatched.push(this.selectedMatchedBenchNode.slice(4)); + this.Bench_unmatched = [...this.Bench_unmatched]; + this.notifyPath('Bench_unmatched'); + const index_N = this.NPU_matched.indexOf(this.selectedMatchedNPUNode.slice(4)); + if (index_N !== -1) { + this.NPU_matched = [...this.NPU_matched.slice(0, index_N), ...this.NPU_matched.slice(index_N + 1)]; + } + const index_B = this.Bench_matched.indexOf(this.selectedMatchedBenchNode.slice(4)); + if (index_B !== -1) { + this.Bench_matched = [...this.Bench_matched.slice(0, index_B), ...this.Bench_matched.slice(index_B + 1)]; + } + const index_M = this.matched.findIndex( + (item) => item[0] === this.selectedMatchedNPUNode.slice(4) && item[1] === this.selectedMatchedBenchNode.slice(4), + ); + if (index_M !== -1) { + this.matched = [...this.matched.slice(0, index_M), ...this.matched.slice(index_M + 1)]; + } + const index_U = this.matchednodeset.findIndex((item) => item[0] === this.selectedMatchedNPUNode); + if (index_U !== -1) { + this.matchednodeset = [...this.matchednodeset.slice(0, index_U), ...this.matchednodeset.slice(index_U + 1)]; + } else { + this.unMatchednodeset.push(this.selectedMatchedNPUNode); + } + this.set('_selectedUnMatchMenu', -1); + const params = new URLSearchParams(); + const run = this.datasets[this._selectedRunIndex].name; + const tag = this.datasets[this._selectedRunIndex].tags[this._selectedTagIndex].tag; + params.set('run', run); + if (tag) params.set('tag', tag); + params.set('NPU', this.selectedMatchedNPUNode); + params.set('Bench', this.selectedMatchedBenchNode); + const precisionPath = getRouter().pluginRouteForSrc('graph_ascend', '/unmatch', params); + const precisionStr = await tf_graph_parser.fetchPbTxt(precisionPath); // 获取异步的 ArrayBuffer + const decoder = new TextDecoder(); + const decodedStr = decoder.decode(precisionStr); // 解码 ArrayBuffer 到字符串 + tf_graph_node.getMatched(this.matchednodeset); + tf_node_info.getMatched(this.matchednodeset); + tf_graph_node.getUnMatched(this.unMatchednodeset); + tf_node_info.getUnMatched(this.unMatchednodeset); + this.set('selectedNode', ''); + this.showDynamicDialog('已取消匹配'); + } + // 写一个如果切换数据清除所有checkbox和所有this.selectColor + @observe('selection') + _clearAllToggleCheckbox() { + this._searchNpu(); + const dropdown = this.$.dropdown as any; + dropdown.selected = 0; + const allCheckboxes = this.shadowRoot?.querySelectorAll('paper-checkbox'); + if (allCheckboxes) { + allCheckboxes.forEach((checkbox) => { + checkbox.checked = false; // 清空每个 checkbox 的选中状态 + }); + } + this.selectColor = []; + this.precisionmenu = []; + } + async _toggleCheckbox(this, event) { + const { batch, step } = this.selection; + const item = event.model.item; + const checkbox = this.shadowRoot?.getElementById(`checkbox-${event.model.index}`) as HTMLInputElement; + const run = this.datasets[this._selectedRunIndex].tags[this._selectedTagIndex].run; + const tag = this.datasets[this._selectedRunIndex].tags[this._selectedTagIndex].tag; + this.precisionmenu = []; + // 更新 selectColor 数组 + if (checkbox && checkbox.checked) { + this.selectColor.push(item[1].value); // 添加选中的颜色 + } else { + const index = this.selectColor.findIndex( + (color) => color[0] === item[1].value[0] && color[1] === item[1].value[1], + ); + if (index !== -1) { + this.selectColor.splice(index, 1); // 取消选中的颜色 + } + } + // 无selectColor,对应没有选择,全清0 + if (this.selectColor.length == 0) { + this.precisionmenu = []; + return; + } + const params = new URLSearchParams(); + params.set('precison', this.selectColor.join(',')); + if (run) params.set('run', run); + if (tag) params.set('tag', tag); + params.set('batch', String(batch === -1 ? -1 : batch - 1)); + params.set('step', String(step === -1 ? -1 : step - 1)); + const precisionPath = getRouter().pluginRouteForSrc('graph_ascend', '/precision', params); + try { + const precisionStr = tf_graph_parser.fetchPbTxt(precisionPath); + this.precisionmenu = JSON.parse(new TextDecoder().decode(await precisionStr).replace(/'/g, '"')) as object; + } catch (e) { + console.error('Parse tooltips failed, please check the format of tooltips in the input vis file'); + } + // 更新数据绑定 + this.notifyPath(`menu.${event.model.index}.checked`, checkbox.checked); + } + + download() { + this.fire('download-image-requested', this._downloadFilename); + } + _clearMicroStep() { + // 也清除一下MicroStep和Step + this.set('_selectedBatch', -1); + this.set('_selectedStep', -1); + } + computedLength(microsteps) { + return microsteps.length > 0 ? microsteps.length - 1 : 0; + } + _updateFileInput(e: Event) { + const file = (e.target as HTMLInputElement).files?.[0]; + if (!file) return; + // Strip off everything before the last "/" and strip off the file + // extension in order to get the name of the PNG for the graph. + let filePath = file.name; + const dotIndex = filePath.lastIndexOf('.'); + if (dotIndex >= 0) { + filePath = filePath.substring(0, dotIndex); + } + const lastSlashIndex = filePath.lastIndexOf('/'); + if (lastSlashIndex >= 0) { + filePath = filePath.substring(lastSlashIndex + 1); + } + this._setDownloadFilename(filePath); + this.set('selectedFile', e); + + tf_graph_util.notifyDebugEvent({ + actionId: tb_debug.GraphDebugEventId.UPLOADED_GRAPH_FROM_FILESYSTEM, + }); + } + _datasetsChanged(newDatasets: Dataset, oldDatasets: Dataset) { + if (oldDatasets !== null) { + // Select the first dataset by default. + this._selectedRunIndex = 0; + } + this._setDownloadFilename(this.datasets[this._selectedRunIndex]?.name); + } + _computeSelection( + datasets: Dataset, + _selectedRunIndex: number, + _selectedTagIndex: number, + _selectedGraphType: tf_graph_common.SelectionType, + _selectedBatch: number, + _selectedStep: number, + ) { + // _selectedStep = this.steplist[_selectedStep] + if (!datasets[_selectedRunIndex] || !datasets[_selectedRunIndex].tags[_selectedTagIndex]) { + return null; + } + return { + run: datasets[_selectedRunIndex].name, + tag: datasets[_selectedRunIndex].tags[_selectedTagIndex].tag, + type: _selectedGraphType, + batch: _selectedBatch, + step: _selectedStep, + }; + } + _selectedRunIndexChanged(runIndex: number) { + if (!this.datasets) return; + this._selectedTagIndex = 0; + this._selectedGraphType = this._getDefaultSelectionType(); + this.traceInputs = false; // Set trace input to off-state. + this._setDownloadFilename(this.datasets[runIndex]?.name); + } + _selectedTagIndexChanged(): void { + this._selectedGraphType = this._getDefaultSelectionType(); + } + _getDefaultSelectionType(): tf_graph_common.SelectionType { + const { datasets, _selectedRunIndex: run, _selectedTagIndex: tag } = this; + if (!datasets || !datasets[run] || !(datasets[run] as any).tags[tag] || (datasets[run] as any).tags[tag].opGraph) { + return tf_graph_common.SelectionType.OP_GRAPH; + } + const datasetForRun = datasets[run] as any; + if (datasetForRun.tags[tag].profile) { + return tf_graph_common.SelectionType.PROFILE; + } + if (datasetForRun.tags[tag].conceptualGraph) { + return tf_graph_common.SelectionType.CONCEPTUAL_GRAPH; + } + return tf_graph_common.SelectionType.OP_GRAPH; + } + _getFile() { + (this.$$('#file') as HTMLElement).click(); + } + _setDownloadFilename(name?: string) { + this._downloadFilename = (name || 'graph') + '.png'; + } + _statsNotNull(stats: tf_graph_proto.StepStats) { + return stats !== null; + } + _toggleLegendOpen(): void { + this.set('_legendOpened', !this._legendOpened); + } + _toggleColorsOpen(): void { + this.set('_Colors', !this._Colors); + } + _getToggleLegendIcon(legendOpened: boolean): string { + // This seems counter-intuitive, but actually makes sense because the + // expand-more button points downwards, and the expand-less button points + // upwards. For most collapsibles, this works because the collapsibles + // expand in the downwards direction. This collapsible expands upwards + // though, so we reverse the icons. + return legendOpened ? 'expand-more' : 'expand-less'; + } + _getSelectionOpGraphDisabled(datasets: Dataset, _selectedRunIndex: number, _selectedTagIndex: number) { + return ( + !datasets[_selectedRunIndex] || + !datasets[_selectedRunIndex].tags[_selectedTagIndex] || + !datasets[_selectedRunIndex].tags[_selectedTagIndex].opGraph + ); + } + _getSelectionProfileDisabled(datasets: Dataset, _selectedRunIndex: number, _selectedTagIndex: number) { + return ( + !datasets[_selectedRunIndex] || + !datasets[_selectedRunIndex].tags[_selectedTagIndex] || + !datasets[_selectedRunIndex].tags[_selectedTagIndex].profile + ); + } + _getSelectionConceptualGraphDisabled(datasets: Dataset, _selectedRunIndex: number, _selectedTagIndex: number) { + return ( + !datasets[_selectedRunIndex] || + !datasets[_selectedRunIndex].tags[_selectedTagIndex] || + !datasets[_selectedRunIndex].tags[_selectedTagIndex].conceptualGraph + ); + } + _getToggleIcon(expanded) { + return expanded ? 'expand-less' : 'expand-more'; + } + _toggleExpanded() { + this._expanded = !this._expanded; + } + triggerMenuExpandEvent(newName) { + const detailsElement = this.shadowRoot?.getElementById(newName) as HTMLDetailsElement; + if (detailsElement?.open) { + const event = new CustomEvent('menu-expand-node-changed', { + detail: { name: newName, open: 'unexpand' }, + }); + document.dispatchEvent(event); + } else { + const event = new CustomEvent('menu-expand-node-changed', { + detail: { name: newName, open: 'expand' }, + }); + document.dispatchEvent(event); + } + } + _getdata(this, event) { + const { batch, step } = this.selection; + var nodeName = event.currentTarget.getAttribute('id'); + if (nodeName !== 'root') { + this.triggerMenuExpandEvent(nodeName); + } + const run = this.datasets[this._selectedRunIndex].name; + const tag = this.datasets[this._selectedRunIndex].tags[this._selectedTagIndex].tag; + const params = new URLSearchParams(); + params.set('run', run); + if (tag) params.set('tag', tag); + params.set('node', nodeName); + params.set('batch', String(batch === -1 ? -1 : batch - 1)); + params.set('step', String(step === -1 ? -1 : step - 1)); + const subnode_list = getRouter().pluginRouteForSrc('graph_ascend', '/subgraph', params); + fetchPbTxt(subnode_list).then((arrayBuffer: ArrayBuffer) => { + parseGraphPbTxt(arrayBuffer).then((graphDef) => { + this.updateGraphData(graphDef, nodeName); + return graphDef; + }); + }); + } + _newMenu() { + let menubox; + menubox = this.shadowRoot?.getElementById('menubox'); + menubox.innerHTML = ''; + const summary = document.createElement('summary'); + const detail = document.createElement('details'); + summary.id = 'root'; + summary.textContent = '目录'; + summary.addEventListener('click', this._getdata.bind(this)); + detail.appendChild(summary); + menubox?.appendChild(detail); + } + _updateColorItems() { + const coloritems = this.shadowRoot?.getElementById('coloritems'); + const tbody = coloritems?.querySelector('tbody'); + if (Object.entries(this.colors).length !== 0) { + if (tbody) { + tbody.innerHTML = ''; + Object.entries(this.colors).forEach(([color, details]) => { + let detailsArray: any[] = []; + detailsArray = [details]; + if (detailsArray) { + const tr = document.createElement('tr'); + const td = document.createElement('td'); + const div = document.createElement('div'); + div.className = 'rectangle'; + div.style.backgroundColor = color; + const td2 = document.createElement('td'); + const td3 = document.createElement('td'); + const divInTd3 = document.createElement('div'); + const span = document.createElement('span'); + const paperTooltip = document.createElement('paper-tooltip'); + const divInPaperTooltip = document.createElement('div'); + divInTd3.className = 'legend-clarifier'; + span.textContent = '?'; + paperTooltip.setAttribute('animation-delay', '0'); + paperTooltip.setAttribute('position', 'right'); + paperTooltip.setAttribute('offset', '0'); + divInPaperTooltip.className = 'custom-tooltip'; + divInPaperTooltip.textContent = detailsArray[0].description; + if (detailsArray[0].value[0] !== undefined) { + td2.textContent = detailsArray[0].value[0] + '-' + detailsArray[0].value[1]; + } else { + td2.textContent = '无匹配节点'; + } + tbody.appendChild(tr); + tr.appendChild(td); + td.appendChild(div); + tr.appendChild(td2); + tr.appendChild(td3); + td3.appendChild(divInTd3); + divInTd3.appendChild(span); + divInTd3.appendChild(paperTooltip); + paperTooltip.appendChild(divInPaperTooltip); + } + }); + } + } else { + if (tbody) { + tbody.innerHTML = ''; + const rows = [ + { color: '#FFFCF3', text: '1-0.8' }, + { color: '#FFEDBE', text: '0.8-0.6' }, + { color: '#FFDC7F', text: '0.6-0.4' }, + { color: '#FFC62E', text: '0.4-0.2' }, + { color: '#ff704d', text: '0.2-0' }, + { color: '#C7C7C7', text: 'Not Connected' }, + ]; + rows.forEach(({ color, text }) => { + const tr = document.createElement('tr'); + tr.innerHTML = `
${text}`; + tbody.appendChild(tr); + }); + } + this.colorset = [ + [ + '#FFFCF3', + { + description: + '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|', + value: [0.8, 1], + }, + ], + [ + '#FFEDBE', + { + description: + '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|', + value: [0.6, 0.8], + }, + ], + [ + '#FFDC7F', + { + description: + '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|', + value: [0.4, 0.6], + }, + ], + [ + '#FFC62E', + { + description: + '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|', + value: [0.2, 0.4], + }, + ], + [ + '#ff704d', + { + description: + '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|', + value: [0, 0.2], + }, + ], + [ + '#C7C7C7', + { + description: '对比过程中节点未匹配上', + value: [], + }, + ], + ]; + } + } + updateGraphData(graphDef, nodeName) { + this.graphDef = graphDef; + let menubox; + let detailsElement; + if (nodeName == 'root') { + menubox = this.shadowRoot?.getElementById('menubox'); + detailsElement = menubox?.querySelector('details'); + } else { + detailsElement = this.shadowRoot?.getElementById(nodeName); + } + this.graphDef.node.forEach((node) => { + const detail = document.createElement('details'); + const summary = document.createElement('summary'); + let nodeNameWithoutPrefix; + detail.id = node.name; + summary.id = node.name; + if (!this.shadowRoot?.getElementById(node.name)) { + if (nodeName == 'root' && node.name.substring(0, 4) == 'N___') { + nodeNameWithoutPrefix = node.name.substring(4) + '(对比)'; + } else if (nodeName == 'root' && node.name.substring(0, 4) == 'B___') { + nodeNameWithoutPrefix = node.name.substring(4) + '(标杆)'; + } else { + if (node.name.substring(0, 4) == 'B___' || node.name.substring(0, 4) == 'N___') { + nodeNameWithoutPrefix = node.name.substring(4); + } else { + nodeNameWithoutPrefix = node.name; + } + } + + if (node.isLeaf) { + summary.classList.add('no-arrow'); + detail.style.paddingLeft = '22px'; + } else { + summary.classList.remove('no-arrow'); + } + summary.style.backgroundColor = 'white'; + if (Object.keys(this.colors).length == 0) { + this.colors = { + '#FFFCF3': { + value: [0.8, 1], + description: + '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|', + }, + '#FFEDBE': { + value: [0.6, 0.8], + description: + '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|', + }, + '#FFDC7F': { + value: [0.4, 0.6], + description: + '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|', + }, + '#FFC62E': { + value: [0.2, 0.4], + description: + '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|', + }, + '#ff704d': { + value: [0, 0.2], + description: + '此节点所有输入输出的统计量相对误差,值越大代表测量值与标杆值的偏差越大,相对误差计算方式:|(测量值-标杆值)/标杆值|', + }, + '#C7C7C7': { + value: [], + description: '对比过程中节点未匹配上', + }, + }; + } + for (const [color, details] of Object.entries(this.colors)) { + let detailsArray: any[] = []; + detailsArray = [details]; + const [start, end] = detailsArray[0].value; + if ( + (start == end && node.precision_index == start) || + (node.precision_index >= start && node.precision_index < end) || + (node.precision_index == end && end == 1) + ) { + summary.style.backgroundColor = color; + break; + } else { + summary.style.backgroundColor = '#C7C7C7'; + } + } + summary.textContent = nodeNameWithoutPrefix; + summary.addEventListener('click', this._getdata.bind(this)); + detail.appendChild(summary); + detailsElement?.appendChild(detail); + } + }); + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_dashboard/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_dashboard/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..ea399c58cd2a366eb33958b4d2a61871f6ff8d34 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_dashboard/index.ts @@ -0,0 +1,418 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { customElement, observe, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import '../polymer/irons_and_papers'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import { Canceller } from '../tf_backend/canceller'; +import { RequestManager } from '../tf_backend/requestManager'; +import { getRouter } from '../tf_backend/router'; +import '../tf_dashboard_common/tf-dashboard-layout'; +import * as tf_storage from '../tf_storage'; +import * as vz_sorting from '../vz_sorting/sorting'; +import '../tf_graph_board/tf-graph-board'; +import * as tf_graph_op from '../tf_graph_common/op'; +import * as tf_graph_render from '../tf_graph_common/render'; +import '../tf_graph_controls/tf-graph-controls'; +import '../tf_graph_loader/tf-graph-dashboard-loader'; + +/** + * The (string) name for the run of the selected dataset in the graph dashboard. + */ +const RUN_STORAGE_KEY = 'run'; +/** + * TODO(stephanwlee): Convert this to proper type when converting to TypeScript. + * @typedef {{ + * tag: ?string, + * displayName: string, + * conceptualGraph: boolean, + * opGraph: boolean, + * profile: boolean, + * }} + */ +const TagItem = {}; +/** + * TODO(stephanwlee): Convert this to proper type when converting to TypeScript. + * @typedef {{ + * name: string, + * tags: !Array, + * }} + */ +const RunItem = {}; + +/** + * tf-graph-dashboard displays a graph from a TensorFlow run. + * + * It has simple behavior: Creates a url-generator and run-generator + * to talk to the backend, and then passes the runsWithGraph (list of runs with + * associated graphs) along with the url generator into tf-graph-board for display. + * + * If there are multiple runs with graphs, the first run's graph is shown + * by default. The user can select a different run from a dropdown menu. + */ +@customElement('graph-app') +class TfGraphDashboard extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + + +
+ +
+

No graph definition files were found.

+

+ To store a graph, create a + tf.summary.FileWriter + and pass the graph either via the constructor, or by calling its + add_graph() method. You may want to check out the + examining the TensorFlow graph tutorial. +

+ +

+ If you’re new to using TensorBoard, and want to find out how to add data and set up your event files, check + out the + README + and perhaps the + TensorBoard tutorial. +

+ +

+ If you think TensorBoard is configured properly, please see + the section of the README devoted to missing data problems + and consider filing an issue on GitHub. +

+
+
+ +
+
+
+ + `; + /** + * @type {!Array} + */ + @property({ + type: Array, + }) + _datasets: any[] = []; + @property({ + type: Boolean, + }) + _datasetsFetched: boolean = false; + @property({ + type: Number, + }) + _selectedDataset: number = 0; + @property({ type: Object, observer: '_renderHierarchyChanged' }) + _renderHierarchy: tf_graph_render.RenderGraphInfo; + @property({ + type: Object, + }) + _requestManager: RequestManager = new RequestManager(); + @property({ + type: Object, + }) + _canceller: Canceller = new Canceller(); + @property({ type: Boolean }) + _debuggerDataEnabled: boolean; + @property({ type: Boolean }) + allStepsModeEnabled: boolean; + @property({ + type: String, + notify: true, + }) + selectedNode: string; + @property({ type: Boolean }) + _isAttached: boolean; + // Whether this dashboard is initialized. This dashboard should only be initialized once. + @property({ type: Boolean }) + _initialized: boolean; + // An array of alerts (in chronological order) provided by debugging libraries on when bad + // values (NaN, +/- Inf) appear. + @property({ + type: Array, + notify: true, + }) + _debuggerNumericAlerts: unknown[] = []; + @property({ type: Array }) + runs: unknown[]; + @property({ + type: String, + notify: true, + observer: '_runObserver', + }) + run: string = tf_storage + .getStringInitializer(RUN_STORAGE_KEY, { + defaultValue: '', + useLocalStorage: false, + }) + .call(this); + @property({ + type: Object, + }) + _selection: object; + @property({ type: Object }) + _compatibilityProvider: object; + @property({ type: Boolean }) + _traceInputs: boolean; + @property({ type: Boolean }) + _autoExtractNodes: boolean; + @property({ type: Object }) + _selectedFile: any; + override attached() { + this.set('_isAttached', true); + } + override detached() { + this.set('_isAttached', false); + } + ready() { + super.ready(); + } + reload() { + if (!this._debuggerDataEnabled) { + // Check if the debugger plugin is enabled now. + this._requestManager.request(getRouter().pluginsListing()).then( + this._canceller.cancellable((result) => { + if (result.cancelled) { + return; + } + if (result.value['debugger']) { + // The debugger plugin is enabled. Request debugger-related + // data. Perhaps the debugger plugin had been disabled + // beforehand because no bad values (NaN, -/+ Inf) had been + // found and muted_if_healthy had been on. + this.set('_debuggerDataEnabled', true); + } + }), + ); + } + } + _fit() { + (this.$$('#graphboard') as any).fit(); + } + _onDownloadImageRequested(event: CustomEvent) { + (this.$$('#graphboard') as any).downloadAsImage(event.detail as string); + } + _getGraphDisplayClassName(_selectedFile: any, _datasets: any[]) { + const isDataValid = _selectedFile || _datasets.length; + return isDataValid ? '' : 'no-graph'; + } + _runObserver = tf_storage.getStringObserver(RUN_STORAGE_KEY, { + defaultValue: '', + polymerProperty: 'run', + useLocalStorage: false, + }); + _fetchDataset() { + return this._requestManager.request(getRouter().pluginRouteForSrc('graph_ascend', '/info')); + } + _fetchDebuggerNumericsAlerts() { + return this._requestManager.request(getRouter().pluginRouteForSrc('debugger', '/numerics_alert_report')); + } + @observe('_isAttached') + _maybeInitializeDashboard() { + var isAttached = this._isAttached; + if (this._initialized || !isAttached) { + // Either this dashboard is already initialized ... or we are not yet ready to initialize. + return; + } + this.set('_compatibilityProvider', new tf_graph_op.TpuCompatibilityProvider()); + // Set this to true so we only initialize once. + this._initialized = true; + this._fetchDataset().then((dataset) => { + const runNames = Object.keys(dataset); + // Transform raw data into UI friendly data. + this._datasets = runNames.sort(vz_sorting.compareTagNames).map((runName) => { + const runData = dataset[runName]; + const tagNames = Object.keys(runData.tags).sort(vz_sorting.compareTagNames); + const tags = tagNames + .map((name) => runData.tags[name]) + .map(({ tag, conceptual_graph, op_graph, profile }) => ({ + tag, + displayName: tag, + conceptualGraph: conceptual_graph, + opGraph: op_graph, + profile, + })); + // Translate a run-wide GraphDef into specially named (without a tag) op graph + // to abstract the difference between run_graph vs. op_graph from other + // components. + const tagsWithV1Graph = runData.run_graph + ? [ + { + tag: null, + displayName: 'Default', + conceptualGraph: false, + opGraph: true, + profile: false, + }, + ...tags, + ] + : tags; + return { name: runName, tags: tagsWithV1Graph }; + }); + this._datasetsFetched = true; + }); + } + @observe('_datasetsFetched', '_datasets', 'run') + _determineSelectedDataset() { + var datasetsFetched = this._datasetsFetched; + var datasets = this._datasets; + var run = this.run; + // By default, load the first dataset. + if (!run) { + // By default, load the first dataset. + this.set('_selectedDataset', 0); + return; + } + // If the URL specifies a dataset, load it. + const dataset = datasets.findIndex((d) => d.name === run); + if (dataset === -1) { + if (datasetsFetched) { + // Tell the user if the dataset cannot be found to avoid misleading + // the user. + const dialog = this.$$('#error-dialog') as any; + dialog.textContent = `No dataset named "${run}" could be found.`; + dialog.open(); + } + return; + } + this.set('_selectedDataset', dataset); + } + @observe('_datasetsFetched', '_datasets', '_selectedDataset') + _updateSelectedDatasetName() { + var datasetsFetched = this._datasetsFetched; + var datasets = this._datasets; + var selectedDataset = this._selectedDataset; + if (!datasetsFetched) return; + // Cannot update `run` to update the hash in case datasets for graph is empty. + if (datasets.length <= selectedDataset) return; + this.set('run', datasets[selectedDataset].name); + } + _datasetsState(datasetsFetched, datasets, state) { + if (!datasetsFetched) return state === 'NOT_LOADED'; + if (!datasets || !datasets.length) return state === 'EMPTY'; + return state === 'PRESENT'; + } + _renderHierarchyChanged(renderHierarchy) { + // Reload any data on the graph when the render hierarchy (which determines which nodes are + // rendered) changes. + this.reload(); + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_debugger_data_card/tf-graph-debugger-data-card.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_debugger_data_card/tf-graph-debugger-data-card.ts new file mode 100644 index 0000000000000000000000000000000000000000..fe28d1db811b6a2157fc72f3187e2dd9839d2ba4 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_debugger_data_card/tf-graph-debugger-data-card.ts @@ -0,0 +1,533 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { computed, customElement, observe, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import * as PolymerDom from '../polymer/dom'; +import '../polymer/irons_and_papers'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import * as tf_graph_render from '../tf_graph_common/render'; +import * as tf_graph_scene from '../tf_graph_common/scene'; +import * as tf_graph_util from '../tf_graph_common/util'; + +@customElement('tf-graph-debugger-data-card') +class TfGraphDebuggerDataCard extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + +
Enable all (not just sampled) steps. Requires slow disk read.
+ +

+ Step of Health Pills: + + + +

+ + +

+ Health Pill + + +

+ +
+

Numeric Alerts

+

Alerts are sorted from top to bottom by increasing timestamp.

+
+ + + + + + + + + +
First OffenseTensor (Device)Event Counts
+
+
+ +
+ `; + @property({ type: Object }) + renderHierarchy: tf_graph_render.RenderGraphInfo; + @property({ + type: Array, + notify: true, + }) + debuggerNumericAlerts: any; + @property({ type: Object }) + nodeNamesToHealthPills: any; + @property({ + type: Number, + notify: true, + }) + healthPillStepIndex: any; + // Only relevant if we are in all steps mode, in which case the user may want to view health + // pills for a specific step. + @property({ + type: Number, + notify: true, + }) + specificHealthPillStep: number = 0; + // Two-ways + @property({ + type: String, + notify: true, + }) + selectedNode: any; + @property({ + type: String, + notify: true, + }) + highlightedNode: any; + // The enum value of the include property of the selected node. + @property({ + type: Number, + notify: true, + }) + selectedNodeInclude: any; + // Whether health pills are currently being loaded, in which case we show a spinner (and the + // current health pills shown might be out of date). + @property({ type: Boolean }) + areHealthPillsLoading: any; + @property({ + type: Array, + }) + healthPillEntries: unknown[] = tf_graph_scene.healthPillEntries; + // When all-steps mode is enabled, the user can request health pills for any step. In this + // mode, Tensorboard makes a request every time the user drags the slider to a different step. + @property({ + type: Boolean, + notify: true, + }) + allStepsModeEnabled: any; + ready() { + super.ready(); + var mainContainer = document.getElementById('mainContainer'); + var scrollbarContainer = document.querySelector('tf-dashboard-layout .scrollbar') as HTMLElement | null; + if (mainContainer && scrollbarContainer) { + // If this component is being used inside of TensorBoard's dashboard layout, it may easily + // cause the dashboard layout element to overflow, giving the user 2 scroll bars. Prevent + // that by hiding whatever content overflows - the user will have to expand the viewport to + // use this debugging card. + mainContainer.style.overflow = 'hidden'; + scrollbarContainer.style.overflow = 'hidden'; + } + } + _healthPillsAvailable(debuggerDataEnabled: any, nodeNamesToHealthPills: any) { + // So long as there is a mapping (even if empty) from node name to health pills, show the + // legend and slider. We do that because, even if no health pills exist at the current step, + // the user may desire to change steps, and the slider must show for the user to do that. + return debuggerDataEnabled && nodeNamesToHealthPills; + } + _computeTensorCountString(healthPillValuesForSelectedNode: any, valueIndex: any) { + if (!healthPillValuesForSelectedNode) { + // No health pill data is available. + return ''; + } + return healthPillValuesForSelectedNode[valueIndex].toFixed(0); + } + @computed( + 'nodeNamesToHealthPills', + 'healthPillStepIndex', + 'selectedNode', + 'allStepsModeEnabled', + 'areHealthPillsLoading', + ) + get healthPillValuesForSelectedNode(): unknown[] | null { + var nodeNamesToHealthPills = this.nodeNamesToHealthPills; + var healthPillStepIndex = this.healthPillStepIndex; + var selectedNode = this.selectedNode; + var allStepsModeEnabled = this.allStepsModeEnabled; + var areHealthPillsLoading = this.areHealthPillsLoading; + if (areHealthPillsLoading) { + // Health pills are loading. Do not render data that is out of date. + return null; + } + if (!selectedNode) { + // No node is selected. + return null; + } + const healthPills = nodeNamesToHealthPills[selectedNode]; + if (!healthPills) { + // This node lacks a health pill. + return null; + } + // If all steps mode is enabled, we use the first health pill in the list because the JSON + // response from the server is a mapping between node name and a list of 1 health pill. + const healthPill = healthPills[allStepsModeEnabled ? 0 : healthPillStepIndex]; + if (!healthPill) { + // This node lacks a health pill at the current step. + return null; + } + // The health pill count values start at 2. Each health pill contains 6 values. + return healthPill.value.slice(2, 8); + } + @computed( + 'nodeNamesToHealthPills', + 'healthPillStepIndex', + 'allStepsModeEnabled', + 'specificHealthPillStep', + 'areHealthPillsLoading', + ) + get _currentStepDisplayValue(): any { + var nodeNamesToHealthPills = this.nodeNamesToHealthPills; + var healthPillStepIndex = this.healthPillStepIndex; + var allStepsModeEnabled = this.allStepsModeEnabled; + var specificHealthPillStep = this.specificHealthPillStep; + var areHealthPillsLoading = this.areHealthPillsLoading; + if (allStepsModeEnabled) { + // The user seeks health pills for specific step from the server. + return specificHealthPillStep.toFixed(0); + } + if (areHealthPillsLoading) { + // The current step is undefined. + return 0; + } + for (let nodeName in nodeNamesToHealthPills) { + // All nodes have the same number of steps stored, so only examine 1 node. We cannot + // directly index into the nodeNamesToHealthPills object because we do not have a key. + // If all steps mode is enabled, we only have 1 step to show. + return nodeNamesToHealthPills[nodeName][healthPillStepIndex].step.toFixed(0); + } + // The current step could not be computed. + return 0; + } + // The biggest step value ever seen. Used to determine what steps of health pills to let the + // user fetch in all steps mode. + @computed('nodeNamesToHealthPills') + get _biggestStepEverSeen(): number { + var nodeNamesToHealthPills = this.nodeNamesToHealthPills; + for (let nodeName in nodeNamesToHealthPills) { + // All nodes have the same number of steps stored, so only examine 1 node. + // The index is 1 less than the count. Tensorboard backend logic guarantees that the length + // of the array will be greater than 1. + var healthPills = nodeNamesToHealthPills[nodeName]; + return Math.max(this._biggestStepEverSeen, healthPills[healthPills.length - 1].step); + } + // No steps seen so far. Default to 0. + return this._biggestStepEverSeen || 0; + } + @computed('nodeNamesToHealthPills') + get _maxStepIndex(): number { + var nodeNamesToHealthPills = this.nodeNamesToHealthPills; + for (let nodeName in nodeNamesToHealthPills) { + // All nodes have the same number of steps stored, so only examine 1 node. + // The index is 1 less than the count. Tensorboard backend logic guarantees that the length + // of the array will be greater than 1. + return nodeNamesToHealthPills[nodeName].length - 1; + } + // Return a falsy value. The slider should be hidden. + return 0; + } + _hasDebuggerNumericAlerts(debuggerNumericAlerts: any) { + return debuggerNumericAlerts && debuggerNumericAlerts.length; + } + @observe('debuggerNumericAlerts') + _updateAlertsList() { + var debuggerNumericAlerts = this.debuggerNumericAlerts; + var alertBody = this.$$('#numeric-alerts-body'); + if (!alertBody) { + return; + } + (alertBody as HTMLElement).innerText = ''; + for (var i = 0; i < debuggerNumericAlerts.length; i++) { + var alert = debuggerNumericAlerts[i]; + var tableRow = document.createElement('tr'); + var timestampTd = document.createElement('td'); + timestampTd.innerText = tf_graph_util.computeHumanFriendlyTime(alert.first_timestamp); + timestampTd.classList.add('first-offense-td'); + tableRow.appendChild(timestampTd); + var tensorDeviceTd = document.createElement('td'); + tensorDeviceTd.classList.add('tensor-device-td'); + var tensorSection = document.createElement('div'); + tensorSection.classList.add('tensor-section-within-table'); + tensorSection.innerText = alert.tensor_name; + this._addOpExpansionListener(tensorSection, alert.tensor_name); + tensorDeviceTd.appendChild(tensorSection); + var deviceSection = document.createElement('div'); + deviceSection.classList.add('device-section-within-table'); + deviceSection.innerText = '(' + alert.device_name + ')'; + tensorDeviceTd.appendChild(deviceSection); + tableRow.appendChild(tensorDeviceTd); + var miniHealthPill = document.createElement('div'); + miniHealthPill.classList.add('mini-health-pill'); + var miniHealthPillTd = document.createElement('td'); + miniHealthPillTd.classList.add('mini-health-pill-td'); + miniHealthPillTd.appendChild(miniHealthPill); + tableRow.appendChild(miniHealthPillTd); + if (alert.neg_inf_event_count) { + var negativeInfCountSection = document.createElement('div'); + negativeInfCountSection.classList.add('negative-inf-mini-health-pill-section'); + negativeInfCountSection.innerText = alert.neg_inf_event_count; + negativeInfCountSection.setAttribute('title', alert.neg_inf_event_count + ' events with -\u221E'); + miniHealthPill.appendChild(negativeInfCountSection); + } + if (alert.pos_inf_event_count) { + var positiveInfCountSection = document.createElement('div'); + positiveInfCountSection.classList.add('positive-inf-mini-health-pill-section'); + positiveInfCountSection.innerText = alert.pos_inf_event_count; + positiveInfCountSection.setAttribute('title', alert.pos_inf_event_count + ' events with +\u221E'); + miniHealthPill.appendChild(positiveInfCountSection); + } + if (alert.nan_event_count) { + var nanCountSection = document.createElement('div'); + nanCountSection.classList.add('nan-mini-health-pill-section'); + nanCountSection.innerText = alert.nan_event_count; + nanCountSection.setAttribute('title', alert.nan_event_count + ' events with NaN'); + miniHealthPill.appendChild(nanCountSection); + } + (PolymerDom.dom(alertBody) as any).appendChild(tableRow); + } + } + // Adds a listener to an element, so that when that element is clicked, the tensor with + // tensorName expands. + _addOpExpansionListener(clickableElement: any, tensorName: any) { + clickableElement.addEventListener('click', () => { + // When the user clicks on a tensor name, expand all nodes until the user can see the + // associated node. + var nameOfNodeToSelect = tf_graph_render.expandUntilNodeIsShown( + document.getElementById('scene'), + this.renderHierarchy, + tensorName, + ); + // Store the current scroll of the graph info card. Node selection alters that scroll, and + // we restore the scroll later. + var previousScrollFromBottom: any; + var graphInfoCard = document.querySelector('tf-graph-info#graph-info') as any; + if (graphInfoCard) { + previousScrollFromBottom = graphInfoCard.scrollHeight - graphInfoCard.scrollTop; + } + // Update the selected node within graph logic. + var previousSelectedNode = this.selectedNode; + this.set('selectedNode', nameOfNodeToSelect); + // Scroll the graph info card back down if necessary so that user can see the alerts section + // again. Selecting the node causes the info card to scroll to the top, which may mean the + // user no longer sees the list of alerts. + var scrollToOriginalLocation = () => { + graphInfoCard.scrollTop = graphInfoCard.scrollHeight - previousScrollFromBottom; + }; + if (graphInfoCard) { + // This component is used within an info card. Restore the original scroll. + if (previousSelectedNode) { + // The card for the selected node has already opened. Immediately restore the scroll. + scrollToOriginalLocation(); + } else { + // Give some time for the DOM of the info card to be created before scrolling down. + window.setTimeout(scrollToOriginalLocation, 20); + } + } + }); + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/node-ioinfo-item.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/node-ioinfo-item.ts new file mode 100644 index 0000000000000000000000000000000000000000..646c3e75a03cfc335b67fab65c4b066167734993 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/node-ioinfo-item.ts @@ -0,0 +1,155 @@ +/* Copyright (c) 2024, Huawei Technologies. + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { computed, customElement, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import * as _ from 'lodash'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; + +interface DisplayAttr { + key: string; + tooltip: string; + value: any; + className: string; +} +@customElement('node-ioinfo-item') +class NodeIoInfoItem extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + +
+ + +
+ [[name]] + +
+
+ `; + @property({ + type: String, + }) + name: string = ''; + @property({ + type: Object, + }) + attrs: Array<{ key: string; tooltip: string; value: any }> = []; + @property({ + type: Boolean, + }) + _attrExpanded: boolean = true; + @computed('attrs') + get _displayAttrs(): DisplayAttr[] { + const errorItem = this.attrs.find((item) => item.key === 'error_key'); + if (errorItem && Array.isArray(errorItem.value)) { + const errorKeys = errorItem.value as string[]; + const tempAttrs: DisplayAttr[] = []; + _.each( + this.attrs.filter((attr) => attr.key !== 'error_key'), + (item) => { + if (errorKeys.includes(item.key)) { + tempAttrs.push({ + ...item, + className: 'attr-list-item attr-error', + }); + } else { + tempAttrs.push({ + ...item, + className: 'attr-list-item', + }); + } + }, + ); + return tempAttrs; + } else { + return _.map(this.attrs, (item) => { + return { + ...item, + className: 'attr-list-item', + }; + }); + } + } + _getToggleIcon(expanded) { + return expanded ? 'expand-less' : 'expand-more'; + } + _toggleInputsDataExpanded() { + this._attrExpanded = !this._attrExpanded; + this.fire('io-item-toggle'); + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/tf-graph-info.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/tf-graph-info.ts new file mode 100644 index 0000000000000000000000000000000000000000..c4911e5894bf5face0fba71b7f8582128df6e4cd --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/tf-graph-info.ts @@ -0,0 +1,131 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { customElement, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import '../polymer/irons_and_papers'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import * as tf_graph from '../tf_graph_common/graph'; +import { Hierarchy } from '../tf_graph_common/hierarchy'; +import * as tf_graph_render from '../tf_graph_common/render'; +import '../tf_graph_debugger_data_card/tf-graph-debugger-data-card'; +import '../tf_graph_op_compat_card/tf-graph-op-compat-card'; +import './tf-node-info'; + +@customElement('tf-graph-info') +class TfGraphInfo extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + + `; + @property({ type: String }) + override title: string; + @property({ type: Object }) + graphHierarchy: Hierarchy; + @property({ type: Object }) + graph: tf_graph.SlimGraph; + @property({ type: Object }) + renderHierarchy: tf_graph_render.RenderGraphInfo; + @property({ type: String }) + compatNodeTitle: string; + // Two-ways + @property({ + type: String, + notify: true, + }) + selectedNode: string; + /** @type {string?} */ + @property({ + type: String, + notify: true, + }) + highlightedNode: string; + @property({ + type: Object, + notify: true, + }) + selectedEdge: tf_graph_render.EdgeData; + // The enum value of the include property of the selected node. + @property({ + type: Number, + notify: true, + }) + selectedNodeInclude: number; + // Whether debugger data is enabled for this instance of Tensorboard. + @property({ type: Boolean }) + debuggerDataEnabled: boolean; + @property({ type: Object }) + tooltips: object; + + ready() { + super.ready(); + + this.addEventListener( + 'node-list-item-click', + this._nodeListItemClicked.bind(this) + ); + this.addEventListener( + 'node-list-item-mouseover', + this._nodeListItemMouseover.bind(this) + ); + this.addEventListener( + 'node-list-item-mouseout', + this._nodeListItemMouseout.bind(this) + ); + } + _nodeListItemClicked(event) { + this.selectedNode = event.detail.nodeName; + this.selectedEdge = null!; + } + _nodeListItemMouseover(event) { + this.highlightedNode = event.detail.nodeName; + } + _nodeListItemMouseout() { + this.highlightedNode = null!; + } + _checkSelected(node, edge) { + return (node && this.graphHierarchy.node(node)) || !!edge; + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/tf-node-info.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/tf-node-info.ts new file mode 100644 index 0000000000000000000000000000000000000000..a5666a4c89a5628aa4bf874a547654b5f7264406 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/tf-node-info.ts @@ -0,0 +1,990 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { computed, customElement, observe, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import * as _ from 'lodash'; +import '../polymer/irons_and_papers'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import '../tf_wbr_string/tf-wbr-string'; +import * as tf_graph_scene_edge from '../tf_graph_common/edge'; +import * as tf_graph from '../tf_graph_common/graph'; +import * as tf_graph_hierarchy from '../tf_graph_common/hierarchy'; +import * as tf_graph_scene_node from '../tf_graph_common/node'; +import * as tf_graph_render from '../tf_graph_common/render'; +import '../tf_graph_common/tf-node-icon'; +import * as tf_graph_util from '../tf_graph_common/util'; +import './tf-node-list-item'; +import './node-ioinfo-item'; + +let matchStorage: [string, number, any[], any[]][] = []; +let unMatchStorage: [] = []; +@customElement('tf-node-info') +class TfNodeInfo extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + + +
+ + +
+ +
+
+
+ + + +
+
+
+ + + + `; + /** + * Note: we intentionally avoid the property name 'nodeName', because + * Polymer Resin does not support it. Resin's contract system prevents + * using native HTMLElement property names unless they have an + * explicit security contract (e.g. 'title' is allowed). + * https://github.com/Polymer/polymer-resin/blob/master/lib/contracts/contracts.js + */ + @property({ type: String }) + selectedNode: string; + @property({ type: Object }) + graphHierarchy: tf_graph_hierarchy.Hierarchy; + @property({ type: Object }) + renderHierarchy: any; + @property({ + type: Object, + computed: '_getNode(selectedNode, graphHierarchy)', + observer: '_resetState', + }) + _node: any; + @property({ + type: Object, + computed: '_getNodeStats(selectedNode, graphHierarchy)', + observer: '_resetState', + }) + _nodeStats: any; + @property({ + type: Number, + observer: '_nodeIncludeStateChanged', + }) + // The enum value of the include property of the selected node. + nodeInclude: number; + @property({ + type: Boolean, + }) + _expanded: boolean = true; + @property({ + type: Boolean, + }) + _suggestExpanded: boolean = true; + @property({ + type: Boolean, + }) + _attrExpanded: boolean = true; + @property({ + type: Boolean, + }) + _inputsExpanded: boolean = true; + @property({ + type: Boolean, + }) + _outputsExpanded: boolean = true; + @property({ + type: Boolean, + }) + _stackExpanded: boolean = true; + @property({ + type: Boolean, + }) + _openedControlPred: boolean = false; + @property({ + type: Boolean, + }) + _openedControlSucc: boolean = false; + @property({ type: String }) + _auxButtonText: string; + @property({ type: String }) + _groupButtonText: string; + @property({ type: Object }) + _templateIndex: (name: string) => number | null = null!; + @property({ type: Object }) + selectedEdge: tf_graph_render.EdgeData; + @property({ type: Object }) + tooltips: object; + expandNode() { + this.fire('_node.expand', (this as any).node); + } + _getNode(selectedNode, graphHierarchy) { + return graphHierarchy.node(selectedNode); + } + @observe('_node') + changeSuggestion() { + let node = this._node; + if (node && node.suggestions && node.suggestions.text) { + let text: string = node.suggestions.text; + _.each(node.suggestions, (value, key) => { + if (key !== 'text') { + text = text.replace(new RegExp(key, 'g'), `${key}`); + } + }); + setTimeout(() => { + let textarea = this.shadowRoot?.getElementById('suggestion-content'); + if (textarea) { + textarea.innerHTML = text; + } + }); + } + } + _getNodeStats(selectedNode, graphHierarchy) { + var node = this._getNode(selectedNode, graphHierarchy); + if (node) { + return node.stats; + } + return null; + } + _getTotalMicros(stats) { + return stats ? stats.getTotalMicros() : 0; + } + @computed('_nodeStats') + get _hasDisplayableNodeStats(): boolean { + var stats = this._nodeStats; + return tf_graph_util.hasDisplayableNodeStats(stats); + } + @computed('_node') + get _suggestions(): object | undefined { + const node = this._node; + if (!node || !node.suggestions || !node.suggestions.text) { + return undefined; + } + + return node.suggestions; + } + @computed('_nodeStats') + get _nodeStatsFormattedBytes(): string | undefined { + var stats = this._nodeStats; + if (!stats || !stats.totalBytes) { + return; + } + return tf_graph_util.convertUnitsToHumanReadable(stats.totalBytes, tf_graph_util.MEMORY_UNITS); + } + @computed('_nodeStats') + get _nodeStatsFormattedComputeTime(): string | undefined { + var stats = this._nodeStats; + if (!stats || !stats.getTotalMicros()) { + return; + } + return tf_graph_util.convertUnitsToHumanReadable(stats.getTotalMicros(), tf_graph_util.TIME_UNITS); + } + @computed('_nodeStats') + get _nodeStatsFormattedOutputSizes(): unknown[] | undefined { + var stats = this._nodeStats; + if (!stats || !stats.outputSize || !stats.outputSize.length) { + return; + } + return _.map(stats.outputSize, function (shape) { + if (shape.length === 0) { + return 'scalar'; + } + return '[' + shape.join(', ') + ']'; + }); + } + _getRenderInfo(selectedNode, renderHierarchy) { + return this.renderHierarchy.getOrCreateRenderNodeByName(selectedNode); + } + _getNodeName(selectedEdge, _node) { + if (!!selectedEdge) { + return 'Edge'; + } else { + if (_node.name.startsWith('N___') || _node.name.startsWith('B___')) { + return _node.name.slice(4); + } else { + return _node.name; + } + } + } + _objectToArray(obj: object): Array<{ key: string; value: any }> { + return Object.keys(obj) + .filter((key) => key !== 'data_name') + .map((key) => { + let tooltip = key; + if (tooltip in this.tooltips) { + tooltip += `: ${this.tooltips[key]}`; + } + return { + key, + tooltip, + value: obj[key], + }; + }); + } + @computed('_node', 'tooltips') + get _inputsData(): Array<{ key: string; value: any }> { + const node = this._node; + let nodeName = ''; + if (!node || !node.inputData) { + return []; + } + if (this._node.name.startsWith('N___') || this._node.name.startsWith('B___')) { + nodeName = this._node.name.slice(4); + } else { + nodeName = this._node.name; + } + // 深拷贝 node.inputData,以免直接修改原始数据 + const inputDataCopy = _.cloneDeep(node.inputData); + const foundMatchItem = matchStorage.find( + (item) => Array.isArray(item) && item.length > 0 && item[0] === this._node.name, + ); + if (foundMatchItem) { + let counter = 0; + for (let key of Object.keys(inputDataCopy)) { + if (this._node.name.startsWith('N___')) { + const item = inputDataCopy[key]; + if (item.hasOwnProperty('Norm') && matchStorage.length !== 0 && foundMatchItem[2].length !== 0) { + for (let [_key, value] of foundMatchItem[2][counter].entries()) { + item[value[0]] = value[1]; + } + } + } + counter++; + if (counter === foundMatchItem[2].length) { + break; + } + } + } + // 使用 inputDataCopy 生成输出,而不影响原始的 node.outputData + return Object.keys(inputDataCopy).map((key, _index) => { + const uniqueKey = `${key.replace(`${nodeName}.`, '')}`; + return { + key: uniqueKey, + value: this._objectToArray(inputDataCopy[key]), + }; + }); + } + + @computed('_node', 'tooltips') + get _outputsData(): Array<{ key: string; value: any }> { + const node = this._node; + let nodeName = ''; + if (node !== undefined) { + if (this._node.name.startsWith('N___') || this._node.name.startsWith('B___')) { + nodeName = this._node.name.slice(4); + } else { + nodeName = this._node.name; + } + } + if (!node || !node.outputData) { + return []; + } + // 深拷贝 node.outputData,以免直接修改原始数据 + const outputDataCopy = _.cloneDeep(node.outputData); + const foundMatchItem = matchStorage.find( + (item) => Array.isArray(item) && item.length > 0 && item[0] === this._node.name, + ); + if (foundMatchItem) { + let counter = 0; + for (let key of Object.keys(outputDataCopy)) { + if (this._node.name.startsWith('N___')) { + const item = outputDataCopy[key]; + if (item.hasOwnProperty('Norm') && matchStorage.length !== 0 && foundMatchItem[3].length !== 0) { + for (let [_key, value] of foundMatchItem[3][counter].entries()) { + item[value[0]] = value[1]; + } + } + } + counter++; + if (counter === foundMatchItem[3].length) { + break; + } + } + } + // 使用 outputDataCopy 生成输出,而不影响原始的 node.outputData + return Object.keys(outputDataCopy).map((key, _index) => { + const uniqueKey = `${key.replace(`${nodeName}.`, '')}`; + return { + key: uniqueKey, + value: this._objectToArray(outputDataCopy[key]), + }; + }); + } + + @computed('_node', 'tooltips') + get _stackData() { + const node = this._node; + if (!node || !node.stackData) { + return []; + } + let newStackData: string = node.stackData.slice(1, -1); + let segments = newStackData.split(", '"); + segments = segments.map((segment) => `${segment.replace(/^'/, '').replace(/'$/, '')}`); + return segments; + } + @computed('_node', 'selectedEdge', 'tooltips') + get _attributes(): unknown[] { + if (!!this.selectedEdge) { + const edge = this.selectedEdge?.label.metaedge.baseEdgeList[0]; + if (!edge || !edge.attr) { + return []; + } + var attrs: any[] = []; + _.each(Object.keys(edge.attr), (key) => { + if (key !== '_path') { + attrs.push({ + key, + value: JSON.stringify(edge.attr?.[key]), + }); + } + }); + } else { + var node = this._node; + this.async(this._resizeList.bind(this, '#attributesList')); + if (!node || !node.attr) { + return []; + } + var attrs: any[] = []; + const tooltips = this.tooltips; + function addAttributeToAttrs(entry, attrs, tooltips) { + if (entry.key !== tf_graph.PRECISION_INDEX) { + let tooltip = entry.key; + if (entry.key in tooltips) { + tooltip += `: ${tooltips[entry.key]}`; + } + attrs.push({ + key: entry.key, + tooltip, + value: processValue(entry.value), + }); + } + } + + function processValue(value) { + if (typeof value === 'object' && !Array.isArray(value)) { + // 安全的 JSON 处理:直接将对象转为 JSON 字符串 + return JSON.stringify(value, (key, val) => { + return val; + }); + } + + if (Array.isArray(value) && value.every((item) => !Array.isArray(item))) { + return JSON.stringify(value); // 一维数组直接转为 JSON 字符串 + } + + if (Array.isArray(value) && value.some(Array.isArray)) { + return value + .map((arr) => + [ + `Name: ${arr[0] || ''}\n`, + `Type: ${arr[1] || ''}\n`, + `A Core: ${arr[2] || ''}\n`, + `Duration(US): ${arr[3] || ''}\n`, + `Input Shapes: ${arr[4] || ''}\n`, + `Input Data Types: ${arr[5] || ''}\n`, + ].join('\n'), + ) + .join('\n'); // 将每个数组项的内容合并成一个长字符串 + } + + return typeof value === 'string' ? value : JSON.stringify(value); + } + + if (node.name.startsWith('N___')) { + _.each(node.attr, (entry) => { + addAttributeToAttrs(entry, attrs, tooltips); + }); + } else { + _.each(node.attr, (entry) => { + addAttributeToAttrs(entry, attrs, tooltips); + }); + } + } + return attrs; + } + @computed('_node') + get _device(): string { + var node = this._node; + return node ? node.device : null; + } + @computed('_node', 'graphHierarchy', 'selectedEdge') + get _successors(): any { + var node = this._node; + var hierarchy = this.graphHierarchy; + this._refreshNodeItemList('inputsList'); + if (node) { + return this._convertEdgeListToEdgeInfoList(hierarchy.getSuccessors(node.name), false, node.isGroupNode); + } else if (this.selectedEdge) { + const successor: tf_graph.Edges = { + control: [], + regular: [this.selectedEdge.label.metaedge], + }; + return this._convertEdgeListToEdgeInfoList(successor, false, false); + } else { + return { regular: [], control: [] }; + } + } + @computed('_node', 'graphHierarchy', 'selectedEdge') + get _predecessors(): any { + var node = this._node; + var hierarchy = this.graphHierarchy; + this._refreshNodeItemList('outputsList'); + if (node) { + return this._convertEdgeListToEdgeInfoList(hierarchy.getPredecessors(node.name), true, node.isGroupNode); + } else if (this.selectedEdge) { + const predecessor: tf_graph.Edges = { + control: [], + regular: [this.selectedEdge.label.metaedge], + }; + return this._convertEdgeListToEdgeInfoList(predecessor, true, false); + } else { + return { regular: [], control: [] }; + } + } + // The iron lists that enumerate ops must be asynchronously updated when + // the data they render changes. This function triggers that update. + _refreshNodeItemList(nodeListId) { + this.async(this._resizeList.bind(this, `#${nodeListId}`)); + } + _onInputsDataItemToggled() { + this.async(this._resizeList.bind(this, '#inputsList')); + } + _onOutputsDataItemToggled() { + this.async(this._resizeList.bind(this, '#outputsList')); + } + _convertEdgeListToEdgeInfoList(list, isPredecessor, isGroupNode) { + var addParentNodes = (path: string, parentNodeList: string[]) => { + const nodeNameList = path.split(tf_graph.NAMESPACE_DELIM); + nodeNameList.pop(); + _.each(nodeNameList, (nodeName) => { + if (!parentNodeList.includes(nodeName)) { + parentNodeList.push(nodeName); + } + }); + }; + /** + * Unpacks the metaedge into a list of base edge information + * that can be rendered. + */ + var unpackMetaedge = (metaedge, edgeKeyList: number[], parentNodeList: string[]) => { + return _.map(metaedge.baseEdgeList, (baseEdge: tf_graph.BaseEdge) => { + var name = isPredecessor ? baseEdge.v : baseEdge.w; + name = name?.split(tf_graph.NAMESPACE_DELIM).pop(); + if (name && parentNodeList.includes(name)) { + return undefined; + } + let edgeId = baseEdge.attr?.id; + if (typeof edgeId === 'number') { + if (edgeKeyList.includes(edgeId)) { + return undefined; + } else { + edgeKeyList.push(edgeId); + } + } + return { + name: name, + node: this._getNode(name, this.graphHierarchy), + edgeLabel: tf_graph_scene_edge.getLabelForBaseEdge(baseEdge, this.renderHierarchy), + renderInfo: this._getRenderInfo(name, this.renderHierarchy), + }; + }); + }; + /** + * Converts a list of metaedges to a list of edge information + * that can be rendered. + */ + var toEdgeInfoList = function (edges) { + var edgeInfoList: any[] = []; + let edgeIdList: number[] = []; + let parentNodeList: string[] = []; + _.each(edges, (metaedge) => { + const metaName = isPredecessor ? metaedge.v : metaedge.w; + metaName && addParentNodes(metaName, parentNodeList); + _.each(metaedge.baseEdgeList, (baseEdge: tf_graph.BaseEdge) => { + const name = isPredecessor ? baseEdge.v : baseEdge.w; + name && addParentNodes(name, parentNodeList); + let nodePath = baseEdge.attr?._path; + if (nodePath) { + addParentNodes(nodePath, parentNodeList); + } + }); + }); + _.each(edges, (metaedge) => { + var name: string = isPredecessor ? metaedge.v : metaedge.w; + addParentNodes(name, parentNodeList); + // Enumerate all the base edges if the node is an OpNode, or the + // metaedge has only 1 edge in it. + edgeInfoList = edgeInfoList.concat(unpackMetaedge(metaedge, edgeIdList, parentNodeList)); + }); + edgeInfoList = edgeInfoList.filter((item) => !!item && !parentNodeList.includes(item.name)); + return edgeInfoList; + }.bind(this); + return { + regular: toEdgeInfoList(list.regular), + control: toEdgeInfoList(list.control), + }; + } + @computed('_node') + get _subnodes(): unknown[] { + var node = this._node; + return node && node.metagraph ? node.metagraph.nodes() : null; + } + @computed('_node', 'selectedEdge') + get _totalPredecessors(): number { + if (!!this.selectedEdge) { + return 0; + } else { + return Object.keys(this._node?.inputData ?? {}).length; + } + } + @computed('_node', 'selectedEdge') + get _totalSuccessors(): number { + if (!!this.selectedEdge) { + return 0; + } else { + return Object.keys(this._node?.outputData ?? {}).length; + } + } + _toggleControlPred() { + this._openedControlPred = !this._openedControlPred; + } + _toggleControlSucc() { + this._openedControlSucc = !this._openedControlSucc; + } + _toggleExpanded() { + this._expanded = !this._expanded; + } + _toggleSuggestExpanded() { + this._suggestExpanded = !this._suggestExpanded; + } + _toggleAttrExpanded() { + this._attrExpanded = !this._attrExpanded; + } + _toggleInputsExpanded() { + this._inputsExpanded = !this._inputsExpanded; + } + _toggleOutputsExpanded() { + this._outputsExpanded = !this._outputsExpanded; + } + _toggleStackExpanded() { + this._stackExpanded = !this._stackExpanded; + } + _getToggleIcon(expanded) { + return expanded ? 'expand-less' : 'expand-more'; + } + _resetState() { + this._openedControlPred = false; + this._openedControlSucc = false; + this.set('_groupButtonText', tf_graph_scene_node.getGroupSettingLabel(this._node)); + } + _resizeList(selector) { + var list = document.querySelector(selector) || this.shadowRoot?.querySelector(selector); + if (list) { + list.fire('iron-resize'); + } + } + _nodeIncludeStateChanged(include, oldInclude) { + this.set('_auxButtonText', tf_graph.getIncludeNodeButtonString(include)); + } + _isLibraryFunction(node) { + // If the node name starts with this string, the node is either a + // library function or a node within it. Those nodes should never be + // extracted into the auxiliary scene group because they represent + // templates for function call nodes, not ops in the graph themselves. + return node && node.name.startsWith(tf_graph.FUNCTION_LIBRARY_NODE_PREFIX); + } + _isInSeries(node) { + return tf_graph_scene_node.canBeInSeries(node); + } + @observe('graphHierarchy') + _graphHierarchyChanged() { + this._templateIndex = this.graphHierarchy.getTemplateIndex(); + this.graphHierarchy.addListener(tf_graph_hierarchy.HierarchyEvent.TEMPLATES_UPDATED, () => { + this._templateIndex = this.graphHierarchy.getTemplateIndex(); + }); + } +} +export function getMatched(matched) { + matchStorage = matched; +} +export function getUnMatched(matched) { + unMatchStorage = matched; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/tf-node-list-item.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/tf-node-list-item.ts new file mode 100644 index 0000000000000000000000000000000000000000..9874e825a367e2013a8bdfb15988969f528509bd --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_info/tf-node-list-item.ts @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { customElement, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import { DarkModeMixin } from '../polymer/dark_mode_mixin'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import '../tf_dashboard_common/tensorboard-color'; +import '../tf_graph_common/tf-node-icon'; + +@customElement('tf-node-list-item') +class TfNodeListItem extends LegacyElementMixin(DarkModeMixin(PolymerElement)) { + static readonly template = html` + +
+
+ + [[name]] + [[edgeLabel]] +
+
+ `; + /** + * The Node for the card itself, on which this item is being drawn. + * This property is a tf.graph.Node. + */ + @property({ type: Object }) + cardNode: any; + /** + * The Node for the item within the card, somehow related to cardNode. + * This property is a tf.graph.Node. + */ + @property({ type: Object }) + itemNode: object; + /** The edge label associated with this item. */ + @property({ type: String }) + edgeLabel: string; + /** + * The render node information for the item node. Used by the graph + * icon in determining fill color. + */ + @property({ type: Object }) + itemRenderInfo: object; + @property({ type: String }) + name: string; + @property({ + type: String, + observer: '_itemTypeChanged', + }) + itemType: string; + @property({ type: Object }) + templateIndex: object; + _itemTypeChanged() { + if (this.itemType !== 'subnode') { + this.$['list-item'].classList.add('clickable'); + } else { + this.$['list-item'].classList.remove('clickable'); + } + } + _nodeListener(event) { + // fire node.click/mouseover/mouseout + this.fire('node-list-item-' + event.type, { + cardNode: this.cardNode?.name, + nodeName: this.name, + type: this.itemType, + }); + } + _fadedClass(itemRenderInfo) { + return itemRenderInfo && itemRenderInfo.isFadedOut ? 'faded' : ''; + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_loader/tf-graph-dashboard-loader.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_loader/tf-graph-dashboard-loader.ts new file mode 100644 index 0000000000000000000000000000000000000000..cde9c8570f80aa5628eb29f76012bf21db950bcf --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_loader/tf-graph-dashboard-loader.ts @@ -0,0 +1,313 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { customElement, observe, property } from '@polymer/decorators'; +import { PolymerElement } from '@polymer/polymer'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import { getRouter } from '../tf_backend/router'; +import * as tf_graph_common from '../tf_graph_common/common'; +import * as tf_graph from '../tf_graph_common/graph'; +import * as tf_graph_hierarchy from '../tf_graph_common/hierarchy'; +import * as tf_graph_loader from '../tf_graph_common/loader'; +import * as tf_graph_op from '../tf_graph_common/op'; +import * as tf_graph_parser from '../tf_graph_common/parser'; +import * as tf_graph_util from '../tf_graph_common/util'; +import * as tf_graph_node from '../tf_graph_common/node'; +import * as tf_graph_controls from '../tf_graph_controls/tf-graph-controls'; + +interface GraphRunTag { + run: string; + tag: string | null; +} + +interface Compoments { + Menu: object; + ToolTip: object; + Colors: object; + MicroSteps: number; + StepList: []; + UnMatchedNode: []; + match: []; +} +/** + * Data loader for tf-graph-dashboard. + * + * The loader loads op graph, conceptual graphs, and RunMetadata associated with + * an op graph which is the major difference from the tf-graph-loader which is + * only capable of loading an op graph. Another difference is that the loader + * takes `selection` from the tf-graph-controls as an input as opposed to URL + * path of an data endpoint. + */ +@customElement('tf-graph-dashboard-loader') +class TfGraphDashboardLoader extends LegacyElementMixin(PolymerElement) { + @property({ type: Array }) + datasets: any[]; + /** + * @type {{value: number, msg: string}} + * + * A number between 0 and 100 denoting the % of progress + * for the progress bar and the displayed message. + */ + @property({ + type: Object, + notify: true, + }) + progress: object; + @property({ type: Object }) + selection: any; + /** + * TODO(stephanwlee): This should be changed to take in FileList or + * the prop should be changed to `fileInput`. + * @type {?Event} + */ + @property({ type: Object }) + selectedFile: object; + @property({ + type: Object, + }) + compatibilityProvider = new tf_graph_op.TpuCompatibilityProvider(); + @property({ + type: Object, + }) + hierarchyParams = tf_graph_hierarchy.DefaultHierarchyParams; + @property({ + type: Object, + readOnly: true, //readonly so outsider can't change this via binding + notify: true, + }) + outGraphHierarchy: tf_graph_hierarchy.Hierarchy; + @property({ + type: Object, + readOnly: true, //readonly so outsider can't change this via binding + notify: true, + }) + outGraph: tf_graph.SlimGraph; + @property({ + type: Object, + readOnly: true, // This property produces data. + notify: true, + }) + outStats: object; + @property({ type: Object }) + _graphRunTag: GraphRunTag; + override _template = null; + @property({ + type: Object, + readOnly: true, //readonly so outsider can't change this via binding + notify: true, + }) + menu: object; + @property({ + type: Object, + readOnly: true, //readonly so outsider can't change this via binding + notify: true, + }) + colorset: object; + @property({ + type: Object, + readOnly: true, //readonly so outsider can't change this via binding + notify: true, + }) + tooltips: object; + @property({ + type: Object, + readOnly: true, //readonly so outsider can't change this via binding + notify: true, + }) + colors: any; + @property({ + type: Object, + readOnly: true, //readonly so outsider can't change this via binding + notify: true, + }) + microsteps: any; + @property({ + type: Object, + readOnly: true, //readonly so outsider can't change this via binding + notify: true, + }) + steplist: any; + @property({ + type: Object, + readOnly: true, //readonly so outsider can't change this via binding + notify: true, + }) + unmatched: object; + @property({ + type: Object, + readOnly: true, //readonly so outsider can't change this via binding + notify: true, + }) + matchedlist: object; + getColors() { + return this.colors; + } + @observe('selection', 'compatibilityProvider') + _selectionChanged(): void { + if (!this.selection) { + return; + } + // selection can change a lot within a microtask. + // Don't fetch too much too fast and introduce race condition. + this.debounce('selectionchange', () => { + this._load(this.selection); + }); + } + _load(selection: tf_graph_controls.Selection) { + const { run, tag, type: selectionType, batch, step } = selection; + switch (selectionType) { + case tf_graph_common.SelectionType.OP_GRAPH: + case tf_graph_common.SelectionType.CONCEPTUAL_GRAPH: { + // Clear stats about the previous graph. + (function () { + this._setOutStats(null); + }).bind(this)(); + const params = new URLSearchParams(); + params.set('run', run); + params.set('conceptual', String(selectionType === tf_graph_common.SelectionType.CONCEPTUAL_GRAPH)); + if (tag) params.set('tag', tag); + params.set('batch', String(batch === -1 ? -1 : batch - 1)); + params.set('step', String(step === -1 ? -1 : step - 1)); + const compomentsPath = getRouter().pluginRouteForSrc('graph_ascend', '/compoments', params); + params.set('node', 'root'); + const graphPath = getRouter().pluginRouteForSrc('graph_ascend', '/subgraph', params); + tf_graph_parser.fetchPbTxt(compomentsPath).then( + function (compomentsStr: BufferSource | undefined) { + let compoments: Compoments = { + Menu: [], + ToolTip: {}, + Colors: {}, + MicroSteps: 0, + StepList: [], + UnMatchedNode: [], + match: [], + }; + try { + compoments = JSON.parse(new TextDecoder().decode(compomentsStr).replace(/'/g, '"')) as Compoments; + } catch (e) { + console.error('Parse tooltips failed, please check the format of tooltips in the input vis file'); + } + const entries = Object.entries(compoments.ToolTip); + const toolTipObject = Object.fromEntries(entries); + this._setMenu(compoments.Menu); + this._setTooltips(toolTipObject); // equals to set('tooltips', tooltips) + this._setColors(compoments.Colors); + this._setColorset(Object.entries(compoments.Colors)); + this._setUnmatched(compoments.UnMatchedNode); + this._setMatchedlist(compoments.match); + tf_graph_node.getColors(compoments.Colors); + const microstepsCount = Number(compoments.MicroSteps); + if (microstepsCount) { + const microstepsArray = ['ALL', ...Array.from({ length: microstepsCount }, (_, index) => index)]; + this._setMicrosteps(microstepsArray); + } else { + this._setMicrosteps([]); + } + const steplistCount = Number(compoments.MicroSteps); + if (steplistCount) { + this._setSteplist(compoments.StepList); + } else { + this._setSteplist([]); + } + }.bind(this), + ); + + return this._fetchAndConstructHierarchicalGraph(graphPath).then(() => { + this._graphRunTag = { run, tag }; + }); + } + case tf_graph_common.SelectionType.PROFILE: { + const { tags } = this.datasets.find(({ name }) => name === run); + const tagMeta = tags.find((t) => t.tag === tag); + // In case current tag misses opGraph but has profile information, + // we fallback to the v1 behavior of fetching the run graph. + const requiredOpGraphTag = tagMeta.opGraph ? tag : null; + console.assert( + tags.find((t) => t.tag === requiredOpGraphTag), + `Required tag (${requiredOpGraphTag}) is missing.`, + ); + const shouldFetchGraph = + !this._graphRunTag || this._graphRunTag.run !== run || this._graphRunTag.tag !== requiredOpGraphTag; + const maybeFetchGraphPromise = shouldFetchGraph + ? this._load({ + run, + tag: requiredOpGraphTag, + type: tf_graph_common.SelectionType.OP_GRAPH, + batch, + step, + }) + : Promise.resolve(); + const params = new URLSearchParams(); + params.set('tag', tag!); + params.set('run', run); + const metadataPath = getRouter().pluginRouteForSrc('graph_ascend', '/run_metadata', params); + return maybeFetchGraphPromise.then(() => this._readAndParseMetadata(metadataPath)); + } + default: + return Promise.reject(new Error(`Unknown selection type: ${selectionType}`)); + } + } + _readAndParseMetadata(path: string) { + // Reset the progress bar to 0. + this.set('progress', { + value: 0, + msg: '', + }); + var tracker = tf_graph_util.getTracker(this); + tf_graph_parser.fetchAndParseMetadata(path, tracker).then( + function (stats) { + this._setOutStats(stats); + }.bind(this), + ); + } + _fetchAndConstructHierarchicalGraph(path: string | null, pbTxtFile?: Blob) { + // Reset the progress bar to 0. + this.set('progress', { + value: 0, + msg: '', + }); + const tracker = tf_graph_util.getTracker(this); + return tf_graph_loader + .fetchAndConstructHierarchicalGraph( + tracker, + path, + pbTxtFile !== undefined ? pbTxtFile : null, + this.compatibilityProvider, + this.hierarchyParams, + ) + .then( + function ({ graph, graphHierarchy }) { + this._setOutGraph(graph); + this._setOutGraphHierarchy(graphHierarchy); + }.bind(this), + ); + } + @observe('selectedFile', 'compatibilityProvider') + _selectedFileChanged() { + var e = this.selectedFile; + if (!e) { + return; + } + const target = (e as any).target as HTMLInputElement; + const file = target.files?.[0]; + if (!file) { + return; + } + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. + target.value = ''; + this._fetchAndConstructHierarchicalGraph(null, file); + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_search/tf-graph-node-search.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_search/tf-graph-node-search.ts new file mode 100644 index 0000000000000000000000000000000000000000..6623419fa0172675ed977f7209b78eb0baf58b8e --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_node_search/tf-graph-node-search.ts @@ -0,0 +1,185 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { computed, customElement, observe, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import * as _ from 'lodash'; +import '../polymer/irons_and_papers'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import * as tb_debug from '../tb_debug'; +import '../tf_dashboard_common/tensorboard-color'; +import * as tf_graph_util from '../tf_graph_common/util'; + +@customElement('tf-graph-node-search') +class TfGraphNodeSearch extends LegacyElementMixin(PolymerElement) { + static readonly template = html` +
+ +
+
+ +
+
+
+ + `; + @property({ type: Object }) + renderHierarchy: any; + @property({ type: Object }) + menu: any; + @property({ + type: String, + notify: true, + }) + selectedNode: string; + @property({ + type: String, + }) + _rawRegexInput: string = ''; + @property({ + type: String, + }) + // The value of the regex input for the last search. + _previousRegexInput: string = ''; + @property({ + type: Number, + }) + _searchTimeoutDelay: number = 150; + @property({ type: Boolean }) + _searchPending: boolean; + @property({ + type: Number, + }) + _maxRegexResults: number = 42; + @property({ type: Array }) + _regexMatches: unknown[]; + @property({ type: Number }) + selectedSide: number = 0; + // This is the cleaned input. + @computed('renderHierarchy', '_rawRegexInput') + get _regexInput(): string { + var renderHierarchy = this.renderHierarchy; + var rawRegexInput = this._rawRegexInput; + return rawRegexInput.trim(); + } + @observe('_regexInput') + _regexInputChanged() { + var regexInput = this._regexInput; + this._requestSearch(); + } + _clearSearchResults() { + this.set('_regexMatches', []); + } + _requestSearch() { + if (this._searchPending) { + return; + } + if (this._regexInput === this._previousRegexInput) { + // No new search is needed. + this._searchPending = false; + return; + } + this._searchPending = true; + this._executeSearch(); + // After some time, perhaps execute another search. + this.async(() => { + this._searchPending = false; + this._requestSearch(); + }, this._searchTimeoutDelay); + } + _executeSearch() { + this._previousRegexInput = this._regexInput; + if (!this._regexInput) { + this._clearSearchResults(); + return; + } + try { + var regex = new RegExp(this._regexInput, 'i'); + } catch (e) { + // The regular expression is invalid. + this._clearSearchResults(); + return; + } + const matchedNodes: any[] = []; + // const nodeMap = this.renderHierarchy.hierarchy.getNodeMap(); + _.each(this.menu[this.selectedSide], (nodeName) => { + if (matchedNodes.length >= this._maxRegexResults) { + // Terminate. + return false; + } + if (!regex.test(nodeName)) { + return; + } + matchedNodes.push(nodeName); + }); + this.set('_regexMatches', matchedNodes); + } + _matchClicked(e) { + const hasBNode = this.renderHierarchy.renderedOpNames.some((name: string) => name.startsWith('B___')); + let prefix = ''; + if (hasBNode) { + if (this.selectedSide == 0) { + prefix = 'N___'; + } else { + prefix = 'B___'; + } + } + const node = prefix + e.model.item; + this.set('selectedNode', node); + tf_graph_util.notifyDebugEvent({ + actionId: tb_debug.GraphDebugEventId.NODE_SEARCH_RESULT_FOCUSED, + }); + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_op_compat_card/tf-graph-op-compat-card.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_op_compat_card/tf-graph-op-compat-card.ts new file mode 100644 index 0000000000000000000000000000000000000000..784c97efc5cdd4f6d673aac4bb4bc33df54be6ff --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_op_compat_card/tf-graph-op-compat-card.ts @@ -0,0 +1,245 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { computed, customElement, observe, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import * as d3 from 'd3'; +import { DarkModeMixin } from '../polymer/dark_mode_mixin'; +import '../polymer/irons_and_papers'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import * as tf_graph from '../tf_graph_common/graph'; +import * as tf_graph_hierarchy from '../tf_graph_common/hierarchy'; +import * as tf_graph_render from '../tf_graph_common/render'; +import './tf-graph-op-compat-list-item'; + +@customElement('tf-graph-op-compat-card') +class TfGraphOpCompatCard extends LegacyElementMixin(DarkModeMixin(PolymerElement)) { + static readonly template = html` + + + + +
+ + +
[[nodeTitle]]
+
+
+
+
+ + + + + + + + + + + +
[[_opCompatScoreLabel]]
+
+
+
+
+
+ + + + + `; + @property({ type: Object }) + graphHierarchy: tf_graph_hierarchy.Hierarchy; + @property({ type: Object }) + renderHierarchy: tf_graph_render.RenderGraphInfo; + @property({ type: String }) + nodeTitle: string; + @property({ + type: Boolean, + }) + _expanded: boolean = true; + @property({ + type: String, + }) + _opCompatColor: string = tf_graph_render.OpNodeColors.COMPATIBLE; + @property({ + type: String, + }) + _opIncompatColor: string = tf_graph_render.OpNodeColors.INCOMPATIBLE; + @property({ type: Object }) + _templateIndex: ((name: string) => number | null) | null = null; + _getNode(nodeName, graphHierarchy) { + return graphHierarchy.node(nodeName); + } + _getRenderInfo(nodeName, renderHierarchy) { + return this.renderHierarchy.getOrCreateRenderNodeByName(nodeName); + } + _toggleExpanded() { + this._expanded = !this._expanded; + } + _getToggleIcon(expanded) { + return expanded ? 'expand-less' : 'expand-more'; + } + _resizeList(selector) { + var list = document.querySelector(selector); + if (list) { + list.fire('iron-resize'); + } + } + @computed('graphHierarchy') + get _incompatibleOpNodes(): Array { + const graphHierarchy = this.graphHierarchy; + if (!graphHierarchy || !graphHierarchy.root) { + return []; + } + this.async(this._resizeList.bind(this, '#incompatibleOpsList')); + return tf_graph_hierarchy.getIncompatibleOps(graphHierarchy); + } + @computed('graphHierarchy') + get _opCompatScore(): number { + var graphHierarchy = this.graphHierarchy; + if (graphHierarchy && graphHierarchy.root) { + var root = graphHierarchy.root; + var numCompat = root.compatibilityHistogram.compatible; + var numIncompat = root.compatibilityHistogram.incompatible; + if (numCompat == 0 && numIncompat == 0) return 0; + var numTotal = numCompat + numIncompat; + // Round the ratio towards negative infinity. + return Math.floor((100 * numCompat) / numTotal) / 100; + } + return 0; + } + @computed('_opCompatScore') + get _opCompatScoreLabel(): string { + var opCompatScore = this._opCompatScore; + return d3.format('.0%')(opCompatScore); + } + @computed('graphHierarchy') + get _totalIncompatOps(): number { + var graphHierarchy = this.graphHierarchy; + if (graphHierarchy && graphHierarchy.root) { + return graphHierarchy.root.compatibilityHistogram.incompatible; + } + return 0; + } + @observe('graphHierarchy') + _graphHierarchyChanged() { + this._templateIndex = this.graphHierarchy.getTemplateIndex(); + this.graphHierarchy.addListener(tf_graph_hierarchy.HierarchyEvent.TEMPLATES_UPDATED, () => { + this._templateIndex = this.graphHierarchy.getTemplateIndex(); + }); + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_op_compat_card/tf-graph-op-compat-list-item.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_op_compat_card/tf-graph-op-compat-list-item.ts new file mode 100644 index 0000000000000000000000000000000000000000..70c7b9e5741d5233e1dec9f7f590c88527845f0c --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_graph_op_compat_card/tf-graph-op-compat-list-item.ts @@ -0,0 +1,130 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { customElement, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; +import { LegacyElementMixin } from '../polymer/legacy_element_mixin'; +import '../tf_dashboard_common/tensorboard-color'; +import '../tf_graph_common/tf-node-icon'; + +@customElement('tf-graph-op-compat-list-item') +class TfGraphOpCompatListItem extends LegacyElementMixin(PolymerElement) { + static readonly template = html` + + +
+
+ + + [[name]] +
+
+ `; + /** + * The Node for the card itself, on which this item is being drawn. + * This property is a tf.graph.Node. + */ + @property({ type: Object }) + cardNode: object; + /** + * The Node for the item within the card, somehow related to cardNode. + * This property is a tf.graph.Node. + */ + @property({ type: Object }) + itemNode: object; + /** The edge label associated with this item. */ + @property({ type: String }) + edgeLabel: string; + /** + * The render node information for the item node. Used by the graph + * icon in determining fill color. + */ + @property({ type: Object }) + itemRenderInfo: object; + @property({ type: String }) + name: string; + @property({ + type: String, + observer: '_itemTypeChanged', + }) + itemType: string; + @property({ type: Object }) + templateIndex: (name: string) => number; + _itemTypeChanged() { + if (this.itemType !== 'subnode') { + this.$['list-item'].classList.add('clickable'); + } else { + this.$['list-item'].classList.remove('clickable'); + } + } + _nodeListener(event) { + // fire node.click/mouseover/mouseout + this.fire('node-list-item-' + event.type, { + nodeName: this.name, + type: this.itemType, + }); + } + _fadedClass(itemRenderInfo) { + return itemRenderInfo && itemRenderInfo.isFadedOut ? 'faded' : ''; + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/index.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..e0d7e905a2b45bb959b35cea9d1e6d32d8da240c --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/index.ts @@ -0,0 +1,16 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +export * from './listeners'; +export * from './storage'; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/listeners.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/listeners.ts new file mode 100644 index 0000000000000000000000000000000000000000..ec410a8f2f8f4c52b6098898443322f9d578bd7c --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/listeners.ts @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +export class ListenKey { + public readonly listener: Function; + constructor(listener: Function) { + this.listener = listener; + } +} +const hashListeners = new Set(); +const storageListeners = new Set(); +window.addEventListener('hashchange', () => { + hashListeners.forEach((listenKey) => listenKey.listener()); +}); +// [1]: The event only triggers when another tab edits the storage. Changing a +// value in current browser tab will NOT trigger below event. +window.addEventListener('storage', () => { + storageListeners.forEach((listenKey) => listenKey.listener()); +}); +export function addHashListener(fn: Function): ListenKey { + const key = new ListenKey(fn); + hashListeners.add(key); + return key; +} +export function addStorageListener(fn: Function): ListenKey { + const key = new ListenKey(fn); + storageListeners.add(key); + return key; +} +export function fireStorageChanged() { + storageListeners.forEach((listenKey) => listenKey.listener()); +} +export function removeHashListenerByKey(key: ListenKey) { + hashListeners.delete(key); +} +export function removeStorageListenerByKey(key: ListenKey) { + storageListeners.delete(key); +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/storage.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/storage.ts new file mode 100644 index 0000000000000000000000000000000000000000..9deac9bdc477ff6b3483cac490ca4c6b19e17ef7 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/storage.ts @@ -0,0 +1,264 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import * as _ from 'lodash'; +import { + addHashListener, + addStorageListener, + fireStorageChanged, + ListenKey, + removeHashListenerByKey, + removeStorageListenerByKey, +} from './listeners'; +import { + componentToDict, + dictToComponent, + readComponent, + TAB_KEY, + unsetFromURI, + updateUrlDict, + writeComponent, +} from './storage_utils'; + +export {getUrlDict as getUrlHashDict} from './storage_utils'; + +/** + * The name of the property for users to set on a Polymer component + * in order for its stored properties to be stored in the URI unambiguously. + * (No need to set this if you want multiple instances of the component to + * share URI state) + * + * Example: + * + * + * The disambiguator should be set to any unique value so that multiple + * instances of the component can store properties in URI storage. + * + * Because it's hard to dereference this variable in HTML property bindings, + * it is NOT safe to change the disambiguator string without find+replace + * across the codebase. + */ +export const DISAMBIGUATOR = 'disambiguator'; + +export const { + get: getString, + set: setString, + getInitializer: getStringInitializer, + getObserver: getStringObserver, + disposeBinding: disposeStringBinding, +} = makeBindings( + (x) => x, + (x) => x +); +export const { + get: getBoolean, + set: setBoolean, + getInitializer: getBooleanInitializer, + getObserver: getBooleanObserver, + disposeBinding: disposeBooleanBinding, +} = makeBindings( + (s) => (s === 'true' ? true : s === 'false' ? false : undefined), + (b) => b.toString() +); +export const { + get: getNumber, + set: setNumber, + getInitializer: getNumberInitializer, + getObserver: getNumberObserver, + disposeBinding: disposeNumberBinding, +} = makeBindings( + (s) => +s, + (n) => n.toString() +); +export const { + get: getObject, + set: setObject, + getInitializer: getObjectInitializer, + getObserver: getObjectObserver, + disposeBinding: disposeObjectBinding, +} = makeBindings( + (s) => JSON.parse(atob(s)) as Record, + (o) => btoa(JSON.stringify(o)) +); +export interface StorageOptions { + defaultValue?: T; + useLocalStorage?: boolean; +} +export interface AutoStorageOptions extends StorageOptions { + polymerProperty?: string; +} +export interface SetterOptions extends StorageOptions { + defaultValue?: T; + useLocalStorage?: boolean; + useLocationReplace?: boolean; +} +export function makeBindings( + fromString: (string) => T, + toString: (T) => string +): { + get: (key: string, option?: StorageOptions) => T; + set: (key: string, value: T, option?: SetterOptions) => void; + getInitializer: (key: string, options: AutoStorageOptions) => Function; + getObserver: (key: string, options: AutoStorageOptions) => Function; + disposeBinding: () => void; +} { + const hashListeners: ListenKey[] = []; + const storageListeners: ListenKey[] = []; + function get(key: string, options: StorageOptions = {}): T { + const {defaultValue, useLocalStorage = false} = options; + const value = useLocalStorage + ? window.localStorage.getItem(key) + : componentToDict(readComponent())[key]; + return value == undefined ? _.cloneDeep(defaultValue)! : fromString(value)!; + } + function set(key: string, value: T, options: SetterOptions = {}): void { + const { + defaultValue, + useLocalStorage = false, + useLocationReplace = false, + } = options; + const stringValue = toString(value); + if (useLocalStorage) { + window.localStorage.setItem(key, stringValue); + // Because of listeners.ts:[1], we need to manually notify all UI elements + // listening to storage within the tab of a change. + fireStorageChanged(); + } else if (!_.isEqual(value, get(key, {useLocalStorage}))) { + if (_.isEqual(value, defaultValue)) { + unsetFromURI(key, useLocationReplace); + } else { + const items = componentToDict(readComponent()); + items[key] = stringValue; + writeComponent(dictToComponent(items), useLocationReplace); + } + } + } + /** + * Returns a function that can be used on a `value` declaration to a Polymer + * property. It updates the `polymerProperty` when storage changes -- i.e., + * when `useLocalStorage`, it listens to storage change from another tab and + * when `useLocalStorage=false`, it listens to hashchange. + */ + function getInitializer(key: string, options: StorageOptions): Function { + const fullOptions = { + defaultValue: options.defaultValue, + polymerProperty: key, + useLocalStorage: false, + ...options, + }; + return function () { + const uriStorageName = getURIStorageName(this, key); + // setComponentValue will be called every time the underlying storage + // changes and is responsible for ensuring that new state will propagate + // to the component with specified property. It is important that this + // function does not re-assign needlessly, to avoid Polymer observer + // churn. + const setComponentValue = () => { + const storedValue = get(uriStorageName, fullOptions); + const currentValue = this[fullOptions.polymerProperty]; + if (!_.isEqual(storedValue, currentValue)) { + this[fullOptions.polymerProperty] = storedValue; + } + }; + const addListener = fullOptions.useLocalStorage + ? addStorageListener + : addHashListener; + // TODO(stephanwlee): When using fakeHash, it _should not_ listen to the + // window.hashchange. + const listenKey = addListener(() => setComponentValue()); + if (fullOptions.useLocalStorage) { + storageListeners.push(listenKey); + } else { + hashListeners.push(listenKey); + } + // Set the value on the property. + setComponentValue(); + return this[fullOptions.polymerProperty]; + }; + } + function disposeBinding() { + hashListeners.forEach((key) => removeHashListenerByKey(key)); + storageListeners.forEach((key) => removeStorageListenerByKey(key)); + } + function getObserver(key: string, options: StorageOptions): Function { + const fullOptions = { + defaultValue: options.defaultValue, + polymerProperty: key, + useLocalStorage: false, + ...options, + }; + return function () { + const uriStorageName = getURIStorageName(this, key); + const newVal = this[fullOptions.polymerProperty]; + set(uriStorageName, newVal, fullOptions); + }; + } + return {get, set, getInitializer, getObserver, disposeBinding}; +} +export function migrateLegacyURLScheme() { + /** + * TODO(psybuzz): move to some compatibility file. + * For each WIT URL param in the legacy scheme, create another URL param + * in the new scheme. Once WIT migrates to using the new plugin API + * `getURLPluginData()`, we can update this method to delete the legacy + * scheme params. + * + * This list of params was taken on 1/16/2020. Luckily, WIT only stored + * strings, booleans. + */ + const witUrlCompatibilitySet = new Set([ + 'examplesPath', + 'hideModelPane2', + 'modelName1', + 'modelName2', + 'inferenceAddress1', + 'inferenceAddress2', + 'modelType', + 'modelVersion1', + 'modelVersion2', + 'modelSignature1', + 'modelSignature2', + 'maxExamples', + 'labelVocabPath', + 'multiClass', + 'sequenceExamples', + 'maxClassesToDisplay', + 'samplingOdds', + 'usePredictApi', + 'predictInputTensor', + 'predictOutputTensor', + ]); + const items = componentToDict(readComponent()); + if (items[TAB_KEY] === 'whatif') { + for (let oldName of witUrlCompatibilitySet) { + if (oldName in items) { + const oldValue = items[oldName]; + items[`p.whatif.${oldName}`] = oldValue; + } + } + } + writeComponent(dictToComponent(items)); + updateUrlDict(items); +} +/** + * Get a unique storage name for a (Polymer component, propertyName) tuple. + * + * DISAMBIGUATOR must be set on the component, if other components use the + * same propertyName. + */ +function getURIStorageName(component: {}, propertyName: string): string { + const d = component[DISAMBIGUATOR]; + const components = d == null ? [propertyName] : [d, propertyName]; + return components.join('.'); +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/storage_utils.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/storage_utils.ts new file mode 100644 index 0000000000000000000000000000000000000000..09a1e57916f3e7ce9e7e11d81b4515ad258c73a1 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/storage_utils.ts @@ -0,0 +1,121 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {getFakeHash, setFakeHash, useHash} from '../tf_globals/globals'; +import {addHashListener} from './listeners'; + +/** + * A keyword that users cannot use, since TensorBoard uses this to store info + * about the active tab. + */ +export const TAB_KEY = '__tab__'; + +export interface StringDict { + [key: string]: string; +} + +// Keep an up-to-date store of URL params, which iframed plugins can request. +let urlDict: StringDict = {}; + +export function getUrlDict(): StringDict { + return urlDict; +} + +export function updateUrlDict(dict: StringDict) { + urlDict = dict; +} + +addHashListener(() => { + urlDict = componentToDict(readComponent()); +}); + +/** + * Read component from URI (e.g. returns "events&runPrefix=train*"). + */ +export function readComponent(): string { + return useHash() ? window.location.hash.slice(1) : getFakeHash(); +} + +/** + * Convert a URI Component into a dictionary of strings. + * Component should consist of key-value pairs joined by a delimiter + * with the exception of the tabName. + * Returns dict consisting of all key-value pairs and + * dict[TAB] = tabName + */ +export function componentToDict(component: string): StringDict { + const items = {} as StringDict; + const tokens = component.split('&'); + tokens.forEach((token) => { + const kv = token.split('='); + // Special backwards compatibility for URI components like #scalars. + if (kv.length === 1) { + items[TAB_KEY] = kv[0]; + } else if (kv.length === 2) { + items[decodeURIComponent(kv[0])] = decodeURIComponent(kv[1]); + } + }); + return items; +} + +/** + * Write component to URI. + */ +export function writeComponent(component: string, useLocationReplace = false) { + if (useHash()) { + if (useLocationReplace) { + const url = new URL(window.location.href); + url.hash = component; + window.history.replaceState(window.history.state, '', url.toString()); + } else { + window.location.hash = component; + } + } else { + setFakeHash(component); + } +} + +/** + * Convert dictionary of strings into a URI Component. + * All key value entries get added as key value pairs in the component, + * with the exception of a key with the TAB value, which if present + * gets prepended to the URI Component string for backwards compatibility + * reasons. + */ +export function dictToComponent(items: StringDict): string { + let component = ''; + // Add the tab name e.g. 'events', 'images', 'histograms' as a prefix + // for backwards compatbility. + if (items[TAB_KEY] !== undefined) { + component += items[TAB_KEY]; + } + // Join other strings with &key=value notation + const nonTab = Object.keys(items) + .map((key) => [key, items[key]]) + .filter((pair) => pair[0] !== TAB_KEY) + .map((pair) => { + return encodeURIComponent(pair[0]) + '=' + encodeURIComponent(pair[1]); + }) + .join('&'); + return nonTab.length > 0 ? component + '&' + nonTab : component; +} + +/** + * Delete a key from the URI. + */ +export function unsetFromURI(key, useLocationReplace = false) { + const items = componentToDict(readComponent()); + delete items[key]; + writeComponent(dictToComponent(items), useLocationReplace); +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/tf-storage-polymer.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/tf-storage-polymer.ts new file mode 100644 index 0000000000000000000000000000000000000000..d21722e812f4df0a6ee4fcb9cb9392ea03af942b --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_storage/tf-storage-polymer.ts @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import {customElement} from '@polymer/decorators'; +import {PolymerElement} from '@polymer/polymer'; +import * as tf_storage from './index'; + +@customElement('tf-storage') +class TfStorage extends PolymerElement { + override _template = null; + tf_storage = tf_storage; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_wbr_string/tf-wbr-string.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_wbr_string/tf-wbr-string.ts new file mode 100644 index 0000000000000000000000000000000000000000..1b89de1e5c69a61ae50fb06a6c02d6bbff467f8f --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/tf_wbr_string/tf-wbr-string.ts @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import { computed, customElement, property } from '@polymer/decorators'; +import { html, PolymerElement } from '@polymer/polymer'; + +// tf-wbr-string safely renders a string, with word break elements inserted +// after substrings that match a regular expression pattern. +@customElement('tf-wbr-string') +class TfWbrString extends PolymerElement { + static readonly template = html` + + + `; + @property({ type: String }) + value: string = ''; + + /** + * Regular expression pattern for specifying delimiters. elements + * are inserted after all nonoverlapping matches. A match that is + * overlapped by another match further to the left is ignored. Empty + * matches will consume the remainder of the string so it is advised + * to not allow empty matches in your pattern. + */ + @property({ type: String }) + delimiterPattern: string = ''; + + @computed('value', 'delimiterPattern') + get _parts(): unknown[] { + var value = this.value; + var delimiterPattern = this.delimiterPattern; + const result: string[] = []; + while (true) { + const delimiterRegExp = new RegExp(delimiterPattern, 'g'); + delimiterRegExp.test(value); + if (delimiterRegExp.lastIndex === 0) { + result.push(value); + break; + } else { + result.push(value.slice(0, delimiterRegExp.lastIndex)); + value = value.slice(delimiterRegExp.lastIndex); + } + } + return result; + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/vz_sorting/sorting.ts b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/vz_sorting/sorting.ts new file mode 100644 index 0000000000000000000000000000000000000000..eb1eb1cfd1f1f2efe0a6b3c12f160de304c5f3e1 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/src/vz_sorting/sorting.ts @@ -0,0 +1,111 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/** + * Compares tag names asciinumerically broken into components. + * + *

This is the comparison function used for sorting most string values in + * TensorBoard. Unlike the standard asciibetical comparator, this function + * knows that 'a10b' > 'a2b'. Fixed point and engineering notation are + * supported. This function also splits the input by slash and underscore to + * perform array comparison. Therefore it knows that 'a/a' < 'a+/a' even + * though '+' < '/' in the ASCII table. + */ +export function compareTagNames(a, b: string): number { + let ai = 0; + let bi = 0; + while (true) { + if (ai === a.length) { + return bi === b.length ? 0 : -1; + } + if (bi === b.length) { + return 1; + } + if (isDigit(a[ai]) && isDigit(b[bi])) { + const ais = ai; + const bis = bi; + ai = consumeNumber(a, ai + 1); + bi = consumeNumber(b, bi + 1); + const an = parseFloat(a.slice(ais, ai)); + const bn = parseFloat(b.slice(bis, bi)); + if (an < bn) { + return -1; + } + if (an > bn) { + return 1; + } + continue; + } + if (isBreak(a[ai])) { + if (!isBreak(b[bi])) { + return -1; + } + } else if (isBreak(b[bi])) { + return 1; + } else if (a[ai] < b[bi]) { + return -1; + } else if (a[ai] > b[bi]) { + return 1; + } + ai++; + bi++; + } +} + +function consumeNumber(s: string, i: number): number { + enum State { + NATURAL, + REAL, + EXPONENT_SIGN, + EXPONENT, + } + let state = State.NATURAL; + for (; i < s.length; i++) { + if (state === State.NATURAL) { + if (s[i] === '.') { + state = State.REAL; + } else if (s[i] === 'e' || s[i] === 'E') { + state = State.EXPONENT_SIGN; + } else if (!isDigit(s[i])) { + break; + } + } else if (state === State.REAL) { + if (s[i] === 'e' || s[i] === 'E') { + state = State.EXPONENT_SIGN; + } else if (!isDigit(s[i])) { + break; + } + } else if (state === State.EXPONENT_SIGN) { + if (isDigit(s[i]) || s[i] === '+' || s[i] === '-') { + state = State.EXPONENT; + } else { + break; + } + } else if (state === State.EXPONENT) { + if (!isDigit(s[i])) { + break; + } + } + } + return i; +} + +function isDigit(c: string): boolean { + return '0' <= c && c <= '9'; +} + +function isBreak(c: string): boolean { + // TODO(@jart): Remove underscore when people stop using it like a slash. + return c === '/' || c === '_' || isDigit(c); +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/tsconfig.json b/plugins/tensorboard-plugins/tb_graph_ascend/fe/tsconfig.json new file mode 100644 index 0000000000000000000000000000000000000000..766093086ddd1e885cc003b7a666aa4179bc7f9a --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "downlevelIteration": true, + "emitDecoratorMetadata": true, + "exactOptionalPropertyTypes": true, + "experimentalDecorators": true, + "importHelpers": true, + "inlineSourceMap": true, + "lib": ["dom", "ES2022", "dom.iterable"], + "noImplicitAny": false, + "moduleResolution": "node", + "module": "ES2022", + "noFallthroughCasesInSwitch": true, + "noImplicitReturns": true, + "noImplicitOverride": true, + "skipLibCheck": true, + "strict": true, + "target": "ES6" + } +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/fe/webpack.config.js b/plugins/tensorboard-plugins/tb_graph_ascend/fe/webpack.config.js new file mode 100644 index 0000000000000000000000000000000000000000..c7c0115bc0089a7b71be48082364823787c2f849 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/fe/webpack.config.js @@ -0,0 +1,66 @@ +/* ------------------------------------------------------------------------- + Copyright (c) 2024, Huawei Technologies. + All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--------------------------------------------------------------------------------------------*/ + +const path = require('path'); +const HtmlWebpackPlugin = require('html-webpack-plugin'); +const { CleanWebpackPlugin } = require('clean-webpack-plugin'); +const InlineChunkHtmlPlugin = require('inline-chunk-html-plugin'); + +module.exports = { + entry: { + app: './src/index', + }, + output: { + filename: 'index.js', + path: path.resolve(__dirname, 'dist'), + }, + module: { + rules: [ + { + test: /\.(html)$/, + use: { + loader: 'html-loader', + }, + }, + { + test: /\.ts?$/, + use: { + loader: 'ts-loader', + options: { + transpileOnly: true, + }, + }, + exclude: /node_modules/, + }, + { + test: /\.css$/i, + use: ['style-loader', 'css-loader'], + }, + ], + }, + resolve: { + extensions: ['.ts', '.js'], + }, + plugins: [ + new CleanWebpackPlugin(), + new HtmlWebpackPlugin({ + inject: 'body', + template: './index.html', + }), + new InlineChunkHtmlPlugin(HtmlWebpackPlugin, [/.*/]), + ], +}; diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/__init__.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7677734993ed8fd3c668644524e9c7ab72a2faae --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024, Huawei Technologies. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/constants.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..4e2e5a53a9e95948c7df84677f59a3f06109b0ca --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/constants.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, Huawei Technologies. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +MAX_FILE_SIZE = 1024 * 1024 * 1024 +PLUGIN_NAME = 'graph_ascend' +PLUGIN_NAME_RUN_METADATA_WITH_GRAPH = 'graph_ascend_run_metadata_graph' +SETS = { + 'Bench': ('Bench', 'B___', 'N___'), + 'NPU': ('NPU', 'N___', 'B___'), + 'B___': ('Bench', 'N___'), + 'N___': ('NPU', 'B___') +} +NA_DATA = [ + ['Max diff', 'N/A'], + ['Min diff', 'N/A'], + ['Mean diff', 'N/A'], + ['L2norm diff', 'N/A'], + ['MaxRelativeErr', 'N/A'], + ['MinRelativeErr', 'N/A'], + ['MeanRelativeErr', 'N/A'], + ['NormRelativeErr', 'N/A'] +] +PREFIX_MAP = { + 'N___': 'NPU', + 'B___': 'Bench' +} +UNMATCH_NAME_SET =['Max diff', 'Min diff', 'Mean diff', 'L2norm diff', 'MaxRelativeErr', 'MinRelativeErr', 'MeanRelativeErr', 'NormRelativeErr', + 'Cosine', 'MaxAbsErr', 'MaxRelativeErr', 'One Thousandth Err Ratio', 'Five Thousandth Err Ratio'] \ No newline at end of file diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..21a52e8095484f26a26a107649fdf17bddf16784 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/plugin.py @@ -0,0 +1,838 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright (c) 2024, Huawei Technologies. +# Adapt to the model hierarchical visualization data collected by the msprobe tool +# ============================================================================== +"""The TensorBoard Graphs plugin.""" + +import json +import os +from werkzeug import wrappers, Response, exceptions + +from tensorboard import errors +from tensorboard.backend import http_util +from tensorboard.plugins import base_plugin +from tensorboard.util import tb_logging +from . import constants + +logger = tb_logging.get_logger() + +class GraphsPlugin(base_plugin.TBPlugin): + """Graphs Plugin for TensorBoard.""" + + plugin_name = constants.PLUGIN_NAME + headers = [('X-Content-Type-Options', 'nosniff')] + + def __init__(self, context): + """Instantiates GraphsPlugin via TensorBoard core. + + Args: + context: A base_plugin.TBContext instance. + """ + super().__init__(context) + self._data_provider = context.data_provider + self.logdir = os.path.abspath(context.logdir.rstrip('/')) + self._current_file_path = None # Store the path of the currently loaded file + self._current_file_data = None # Store the data of the currently loaded file + self._current_tag = None # Store the tag of the currently loaded file + self.batch_id = 0 # 将 batch_id 声明为实例变量 + self.step_id = 0 # 可以同样声明 step_id + self.dfs_node_ids = [] # batch和step没变的话就将所有的nodename存起来,方便快速读取 + self.check_batch_id = 0 # 来配合node_ids监察用的,他不变node_ids就不用重新读取了 + self.check_step_id = 0 # 同上 + self.check_tag = None + + def get_plugin_apps(self): + return { + '/index.js': self.static_file_route, + '/index.html': self.static_file_route, + "/info": self.info_route, + "/compoments": self.get_all_data, + "/expandnodes": self.get_all_upnodes, + "/precision": self.get_all_precisonNodes, + "/unmatch": self.get_unmatch, + "/match": self.get_match, + "/parent": self.get_parent_node, + "/subgraph": self.subgraph_route, + } + + def is_active(self): + """The graphs plugin is active iff any run has a graph.""" + for _, _, files in os.walk(self.logdir): + for file in files: + if file.endswith('.vis'): + return True + return False + + def data_plugin_names(self): + return ( + constants.PLUGIN_NAME, + constants.PLUGIN_NAME_RUN_METADATA_WITH_GRAPH, + ) + + def frontend_metadata(self): + return base_plugin.FrontendMetadata( + es_module_path='/index.js', + disable_reload=True, + ) + + def info_impl(self): + """Returns a dict of all runs and their data availabilities, including a flag indicating if a .vis file is present.""" + result = {} + + def add_row_item(run, tag=None, is_vis=False): + run_item = result.setdefault( + run, + { + "run": run, + "tags": {}, + "run_graph": False, + }, + ) + + tag_item = None + if tag: + tag_item = run_item.get("tags").setdefault( + tag, + { + "tag": tag, + "conceptual_graph": False, + "op_graph": False, + "profile": False, + }, + ) + return (run_item, tag_item) + + run_tag_pairs = self._get_run_dirs() + for run, tag in run_tag_pairs: + add_row_item(run, tag, is_vis=True) + return result + + # 拿jsondata的 + def get_jsondata(self, request): + run = request.args.get("run") + tag = request.args.get("tag") + json_data = None + error_message = None + + if run is None or tag is None: + error_message = 'The query parameters "run" and "tag" are required' + return json_data, error_message + + run_dir = os.path.join(self.logdir, run) + file_path = self._load_json_file(run_dir, tag) + json_data = self._read_json_file(file_path) + self._current_file_data = json_data + if json_data is None: + error_message = f'vis file for tag "{tag}" not found in run "{run}"' + + return json_data, error_message + + # 拿所有nodename的 + def get_all_nodeName(self, json_data, request): + npu_ids, bench_ids = [], [] + try: + batch = int(request.args.get("batch")) + step = int(request.args.get("step")) + except ValueError: + logger.error('The param "batch" or "step" does not exist or not a valid value') + # 获取 NPU 和 Bench 数据 + npu_data = self.json_get(json_data, 'NPU') + bench_data = self.json_get(json_data, 'Bench') + def extract_ids(nodes_data, id_list): + for node_name in nodes_data.get("node"): + id_list.append(node_name) + def traverse_npu(subnodes): + for node in subnodes: + node_data = self.json_get(npu_data, 'node', node) if npu_data else self.json_get(json_data, 'node', node) + micro_step_id = node_data.get('micro_step_id') + if micro_step_id == batch or micro_step_id is None: + npu_ids.append(node) + traverse_npu(node_data.get('subnodes', [])) + + # 提取 NPU 节点 ID + if batch == -1 and step == -1: + extract_ids(npu_data or json_data, npu_ids) + else: + root = (npu_data or json_data).get('root') + root_subnodes = self.json_get((npu_data or json_data), 'node', root, 'subnodes') + traverse_npu(root_subnodes) + + # 提取 Bench 节点 ID + extract_ids(bench_data or json_data, bench_ids) + # 返回格式为 [[NPU节点ID列表], [Bench节点ID列表]] + return [npu_ids, bench_ids] + + def dfs_collect_nodes(self, json_data, request): + all_node_names = [] + try: + batch = int(request.args.get("batch")) + step = int(request.args.get("step")) + except ValueError: + logger.error('The param "batch" or "step" does not exist or not a valid value') + def should_include_node(micro_step_id, step_id): + return (micro_step_id is batch or batch == -1 or micro_step_id is None) and (step_id is step or step == -1 or step_id is None) + + nodes_data = self.json_get(json_data, 'NPU', 'node') or self.json_get(json_data, 'node') + for node in nodes_data: + micro_step_id = self.json_get(nodes_data, node, 'micro_step_id') + step_id = self.json_get(nodes_data, node, 'step_id') + if should_include_node(micro_step_id, step_id) and not self.json_get(nodes_data, node, 'subnodes'): + all_node_names.append(node) + return all_node_names + + #拿所有precisonNodes的,与controls的精度筛选联动 + @wrappers.Request.application + def get_all_precisonNodes(self, request): + grouped_precision_set, precision_node_ids = [], [] + precision_set = request.args.get("precison") + precision_set_str = precision_set.split(',') + precision_none = 0 + if '无匹配节点' in precision_set_str: + precision_set_str = [p for p in precision_set_str if p != '无匹配节点'] + precision_none = 1 + grouped_precision_set = [list(map(float, precision_set_str[i:i+2])) for i in range(0, len(precision_set_str), 2)] + tag = request.args.get("tag") + json_data = self.check_jsondata(request) + def has_conditions_changed(tag): + return ( + self.check_batch_id != self.batch_id or + self.check_step_id != self.step_id or + self.check_tag != tag or + self.check_tag is None + ) + if has_conditions_changed(tag): + self.dfs_node_ids = self.dfs_collect_nodes(json_data, request) + self.check_batch_id = self.batch_id + self.check_step_id = self.step_id + self.check_tag = tag + node_ids = self.dfs_node_ids + for node in node_ids: + node_data = self.json_get(json_data, 'NPU', 'node', node, 'data') or self.json_get(json_data, 'node', node, 'data') + precision = node_data.get('precision_index') if node_data is not None else None + # 检查 precision 是否在 grouped_precision_set 的任何子列表中 + if precision is not None: + for group in grouped_precision_set: + if all(g is not None for g in group) and group[0] <= precision <= group[1]: # 判断 precision 是否在某个子列表中 + precision_node_ids.append(node) + else: + if precision_none == 1: + precision_node_ids.append(node) + return http_util.Respond(request, precision_node_ids, "application/json") + + def group_precision_set(self, precision_set): + if len(precision_set) % 2 != 0: + raise ValueError('The number of elements in precision_set is not even') + grouped_precision_set = [precision_set[i:i+2] for i in range(0, len(precision_set), 2)] + return grouped_precision_set + + def get_all_unmatchedNodes(self, allNodeName, request): + json_data = self.check_jsondata(request) + is_npu_present = 'NPU' in json_data + def collect_unmatched_nodes(node_list, *path): + return [node for node in node_list if not self.json_get(json_data, *path, node, 'matched_node_link')] + NPU_Unmatched = collect_unmatched_nodes(allNodeName[0], 'NPU', 'node') if is_npu_present else \ + collect_unmatched_nodes(allNodeName[0], 'node') + Bench_Unmatched = collect_unmatched_nodes(allNodeName[1], 'Bench', 'node') if is_npu_present else [] + return [NPU_Unmatched, Bench_Unmatched] + + @wrappers.Request.application + def get_match(self, request): + NPU_node = request.args.get("NPU")[4:] + Bench_node = request.args.get("Bench")[4:] + run = request.args.get('run') + tag = request.args.get('tag') + file_path = os.path.join(run, f"{tag}.vis") + json_data = self.check_jsondata(request) + NPU_node_data = self.json_get(json_data, 'NPU', 'node', NPU_node) + Bench_node_data = self.json_get(json_data, 'Bench', 'node', Bench_node) + # 获取输入输出数据 + NPU_input_data, Bench_input_data, NPU_output_data, Bench_output_data = self.get_input_output_data( + NPU_node_data, Bench_node_data) + # 计算输入输出的最小长度 + input_min_length, output_min_length = self.calculate_min_length(NPU_input_data, Bench_input_data, + NPU_output_data, Bench_output_data) + precision_index, max_precision_index, result = -1, -1, 0 + input, output = [], [] + if input_min_length > 0: + npu_input_keys, bench_input_keys = self.get_keys(NPU_input_data, Bench_input_data) + # 处理输入数据 + data_type = 'input_data' + data_set = { + 'input_data': NPU_input_data, + 'output_data': Bench_input_data, + 'npu_keys': npu_input_keys, + 'bench_keys': bench_input_keys, + 'data_type': data_type, + 'precision_index': precision_index, + 'file_path': file_path + } + max_precision_index = self.process_input_data(input, data_set, NPU_node) + if max_precision_index > precision_index: + precision_index = max_precision_index + + if output_min_length > 0: + npu_output_keys, bench_output_keys = self.get_keys(NPU_output_data, Bench_output_data) + # 处理输出数据 + data_type = 'output_data' + data_set = { + 'input_data': NPU_output_data, + 'output_data': Bench_output_data, + 'npu_keys': npu_output_keys, + 'bench_keys': bench_output_keys, + 'data_type': data_type, + 'precision_index': precision_index, + 'file_path': file_path + } + max_precision_index = self.process_output_data(output, data_set, NPU_node) + if max_precision_index > precision_index: + precision_index = max_precision_index + + result = [request.args.get("NPU"), precision_index, input, output] + + if not input and not output and precision_index == -1: + return http_util.Respond(request, [], "application/json") + + # 添加匹配的节点链接 + matched_result = self.add_matched_node_link(NPU_node, Bench_node) + if not matched_result: + result = {"error": f"Unable to link matched nodes for NPU_node: {NPU_node} and Bench_node: {Bench_node}"} + return http_util.Respond(request, result, "application/json", status_code=400) + + precision_result = self.add_precision_index(NPU_node, precision_index) + if not precision_result: + result = {"error": f"Unable to add precision index for NPU_node: {NPU_node} with precision_index: {precision_index}"} + return http_util.Respond(request, result, "application/json", status_code=400) + + # 保存修改后的内容 + self.save_results(file_path) + + return http_util.Respond(request, result, "application/json") + + def get_input_output_data(self, NPU_node_data, Bench_node_data): + """提取输入输出数据""" + NPU_input_data = self.json_get(NPU_node_data, 'input_data') + Bench_input_data = self.json_get(Bench_node_data, 'input_data') + NPU_output_data = self.json_get(NPU_node_data, 'output_data') + Bench_output_data = self.json_get(Bench_node_data, 'output_data') + return NPU_input_data, Bench_input_data, NPU_output_data, Bench_output_data + + def calculate_min_length(self, NPU_input_data, Bench_input_data, NPU_output_data, Bench_output_data): + """计算输入输出数据的最小长度""" + input_min_length, output_min_length = 0, 0 + if NPU_input_data and Bench_input_data: + input_min_length = min(len(NPU_input_data), len(Bench_input_data)) + if NPU_output_data and Bench_output_data: + output_min_length = min(len(NPU_output_data), len(Bench_output_data)) + return input_min_length, output_min_length + + def get_keys(self, NPU_data, Bench_data): + """提取数据的keys""" + npu_keys = list(NPU_data.keys()) + bench_keys = list(Bench_data.keys()) + return npu_keys, bench_keys + + def safe_float_conversion(self, value, default=None): + """安全地转换为float""" + try: + return float(value) + except ValueError: + return None + + def calculate_diff_and_relative_error(self, npu_data, bench_data): + """计算最大值、最小值、均值和L2范数的差异以及相对误差""" + results = {} + + for key in ['Max', 'Min', 'Mean', 'Norm']: + NPU_value = self.safe_float_conversion(npu_data.get(key, 0)) + Bench_value = self.safe_float_conversion(bench_data.get(key, 0)) + + # 计算差异 + if NPU_value is not None: + results[f'NPU_{key}'] = NPU_value + else: + results[f'NPU_{key}'] = 'N/A' + if Bench_value is not None: + results[f'Bench_{key}'] = Bench_value + else: + results[f'Bench_{key}'] = 'N/A' + + # 计算相对误差 + if NPU_value is not None and Bench_value is not None and Bench_value != 0: + results[f'{key}_relative_err'] = f"{abs(NPU_value - Bench_value) / float(Bench_value) * 100:.6f}%" + else: + results[f'{key}_relative_err'] = "N/A" + + return results + + def process_data(self, data, data_set, NPU_node): + """处理数据的公共逻辑""" + for npu_key, bench_key in zip(data_set['npu_keys'], data_set['bench_keys']): + if data_set['input_data'].get(npu_key) not in ['None', None] and data_set['output_data'].get(bench_key) not in ['None', None]: + npu_value = data_set['input_data'][npu_key] + bench_value = data_set['output_data'][bench_key] + + if self._current_file_data.get('task') == 'md5': + if data_set['input_data'].get('md5') != data_set['output_data'].get('md5'): + data_set['precision_index'] = 0 + break + data_set['precision_index'] = 1 + + # 计算差异和相对误差 + results = self.calculate_diff_and_relative_error(npu_value, bench_value) + + # 提取计算的结果 + NPU_max, Bench_max = results['NPU_Max'], results['Bench_Max'] + NPU_min, Bench_min = results['NPU_Min'], results['Bench_Min'] + NPU_mean, Bench_mean = results['NPU_Mean'], results['Bench_Mean'] + NPU_norm, Bench_norm = results['NPU_Norm'], results['Bench_Norm'] + + max_relative_err = results['Max_relative_err'] + min_relative_err = results['Min_relative_err'] + mean_relative_err = results['Mean_relative_err'] + norm_relative_err = results['Norm_relative_err'] + + # 创建数据条目 + data_entry = [ + ['Max diff', NPU_max - Bench_max], + ['Min diff', NPU_min - Bench_min], + ['Mean diff', NPU_mean - Bench_mean], + ['L2norm diff', NPU_norm - Bench_norm], + ['MaxRelativeErr', max_relative_err], + ['MinRelativeErr', min_relative_err], + ['MeanRelativeErr', mean_relative_err], + ['NormRelativeErr', norm_relative_err] + ] + + # 更新最大差异 + Max_value = max(abs(NPU_max - Bench_max), abs(NPU_min - Bench_min), abs(NPU_mean - Bench_mean), abs(NPU_norm - Bench_norm)) + if Max_value > data_set['precision_index']: + data_set['precision_index'] = Max_value + + # 将数据添加到列表中 + data.append(data_entry) + + # 保存数据到文件 + self.save_data(data_set['file_path'], NPU_node, npu_key, data_entry, data_set['data_type']) + + return data_set['precision_index'] + + def process_input_data(self, input, data_set, NPU_node): + """处理输入数据""" + return self.process_data(input, data_set, NPU_node) + + def process_output_data(self, output, data_set, NPU_node): + """处理输出数据""" + return self.process_data(output, data_set, NPU_node) + + def save_data(self, file_path, NPU_node, npu_key, data, data_type): + """保存数据到文件""" + for key, value in data: + self.json_get(self._current_file_data, "NPU", "node", NPU_node, data_type, npu_key)[key] = value + with open(file_path, "w", encoding="utf-8") as file: + json.dump(self._current_file_data, file, ensure_ascii=False, indent=4) + + def add_matched_node_link(self, NPU_node, Bench_node): + """添加匹配的节点链接""" + # 获取匹配的节点链接 + NPU_matched_node_link = self.json_get(self._current_file_data, "NPU", "node", NPU_node, "matched_node_link") + Bench_matched_node_link = self.json_get(self._current_file_data, "Bench", "node", Bench_node, "matched_node_link") + + # 检查是否为可变列表类型 + if isinstance(NPU_matched_node_link, list) and isinstance(Bench_matched_node_link, list): + NPU_matched_node_link.append(Bench_node) + Bench_matched_node_link.append(NPU_node) + else: + logger.error(f'Error: add matched_node_link failed.') + return False + + # 更新 match 列表 + if self._current_file_data.get('match') is None: + self._current_file_data["match"] = [] + + self._current_file_data.get('match').append([NPU_node, Bench_node]) + return True + + def add_precision_index(self, NPU_node, precision_index): + """添加精度误差""" + # 获取 NPU 数据 + NPU_precison_index = self.json_get(self._current_file_data, "NPU", "node", NPU_node, "data") + + # 检查 NPU_precison_index 是否为字典类型 + if isinstance(NPU_precison_index, dict): + NPU_precison_index["precision_index"] = precision_index + else: + logger.error(f'Error: add precision_index failed.') + return False + return True + + def save_results(self, file_path): + """保存修改后的结果""" + with open(file_path, "w", encoding="utf-8") as file: + json.dump(self._current_file_data, file, ensure_ascii=False, indent=4) + + + @wrappers.Request.application + def get_unmatch(self, request): + # 需要处理的项目为input, output, matched_node_link data里面的precision_index 和 来判断的 + NPU_node = request.args.get("NPU", "")[4:] # 防止获取为空字符串 + Bench_node = request.args.get("Bench", "")[4:] + run = request.args.get('run', "") + tag = request.args.get('tag', "") + file_path = os.path.join(run, f"{tag}.vis") + + # precision_index + try: + npu_node_data = self._current_file_data.get("NPU", {}).get("node", {}).get(NPU_node, {}).get('data', {}) + if "precision_index" in npu_node_data: + del npu_node_data["precision_index"] + except Exception as e: + error = {"error": f"Error removing precision_index: {str(e)}"} + return http_util.Respond(request, error, "application/json", status_code=500) + + # matched_node_link + try: + self._current_file_data["NPU"]["node"][NPU_node]["matched_node_link"] = [] + self._current_file_data["Bench"]["node"][Bench_node]["matched_node_link"] = [] + except Exception as e: + error = {"error": f"Error resetting matched_node_link: {str(e)}"} + return http_util.Respond(request, error, "application/json", status_code=500) + + # input + input_data = self.json_get(self._current_file_data, "NPU", "node", NPU_node, "input_data") + if input_data is None: + error = {"error": f"Error resetting input_data: {str(e)}"} + return http_util.Respond(request, error, "application/json", status_code=500) + for item in input_data: + keys_to_remove = [key for key, value in input_data[item].items() if key in constants.UNMATCH_NAME_SET] + for key in keys_to_remove: + del input_data[item][key] + + # output + output_data = self.json_get(self._current_file_data, "NPU", "node", NPU_node, "output_data") + if output_data is None: + error = {"error": f"Error resetting output_data: {str(e)}"} + return http_util.Respond(request, error, "application/json", status_code=500) + for item in output_data: + keys_to_remove = [key for key, value in output_data[item].items() if key in constants.UNMATCH_NAME_SET] + for key in keys_to_remove: + del output_data[item][key] + # match + self._current_file_data.get('match').remove([NPU_node, Bench_node]) + + with open(file_path, "w", encoding="utf-8") as file: + json.dump(self._current_file_data, file, ensure_ascii=False, indent=4) + result = 'unmatched!!!' + return http_util.Respond(request, result, "application/json") + + @wrappers.Request.application + def get_parent_node(self, request): + node = request.args.get("node")[4:] # 获取节点信息 + prefix = request.args.get("node")[:4] # 获取前缀 + json_data = self.check_jsondata(request) # 检查请求中的 JSON 数据 + + def find_upnode(node): + matched_node_link_list = self.json_get(json_data, constants.PREFIX_MAP[prefix], 'node', node, 'matched_node_link') + + if matched_node_link_list: + result = matched_node_link_list[-1] # 获取匹配的最后一个节点 + return http_util.Respond(request, result, "application/json") # 返回响应 + + # 如果没有找到 matched_node_link,继续递归查找上级节点 + else: + upnode = self.json_get(json_data, constants.PREFIX_MAP[prefix], 'node', node, 'upnode') + if upnode: + return find_upnode(upnode) # 递归查找上级节点 + else: + return http_util.Respond(request, {}, "application/json") # 如果没有找到上级节点,返回空响应 + + return find_upnode(node) + + #拿json_data里面所有配置数据的 + @wrappers.Request.application + def get_all_data(self, request): + """Returns all data in json format.""" + keys = ['ToolTip', 'Colors'] + response_data = {} + tag = request.args.get("tag") + json_data = self.check_jsondata(request) + allNodeName = self.get_all_nodeName(json_data, request) + response_data['Menu'] = allNodeName + response_data['UnMatchedNode'] = self.get_all_unmatchedNodes(allNodeName , request) + self._current_tag = tag + for field in ['MicroSteps', 'StepList', 'match']: + if json_data.get(field, {}): + keys.append(field) + for key in keys: + if key == 'StepList' and 'ALL' not in json_data.get('StepList', {}): + json_data[key].insert(0, 'ALL') + response_data[key] = json_data.get(key, {}) + return http_util.Respond(request, response_data, "application/json") + + @wrappers.Request.application + def static_file_route(self, request): + filename = os.path.basename(request.path) + extension = os.path.splitext(filename)[1] + if extension == '.html': + mimetype = 'text/html' + elif extension == '.js': + mimetype = 'application/javascript' + else: + mimetype = 'application/octet-stream' + filepath = os.path.join(os.path.dirname(__file__), 'static', filename) + try: + with open(filepath, 'rb') as infile: + contents = infile.read() + except IOError as e: + raise exceptions.NotFound('404 Not Found') from e + return Response( + contents, content_type=mimetype, headers=GraphsPlugin.headers + ) + + #方便多层级展开的upnodes节点集合,与tf-graph的_menuSelectedNodeExpand联动 + @wrappers.Request.application + def get_all_upnodes(self, request): + npu_upnodes_list, matched_upnodes_list, node_list = [], [], [] + node, matched_node, prefix = '', '', '' + node_arg = request.args.get('node') + json_data = self.check_jsondata(request) + prefix = str(node_arg)[:4] if str(node_arg)[:4] in constants.PREFIX_MAP else '' + node = node_arg[4:] if prefix in constants.PREFIX_MAP else node_arg + if prefix in constants.PREFIX_MAP and json_data.get(constants.PREFIX_MAP[prefix], {}): + node_list = json_data[constants.PREFIX_MAP[prefix]].get('node', {}) + else: + node_list = json_data.get('node', {}) + matched_node = ( + node_list.get(node, {}).get('matched_node_link', [])[-1] + if node_list.get(node, {}).get('matched_node_link') + else None + ) + def get_upnodes(node, prefix): + upnodes_list = [] + if prefix == '': + node_list = json_data.get('node', {}) + else: + node_list = json_data.get('NPU' if prefix == 'N___' else 'Bench', {}).get('node', {}) + while node in node_list: + upnode = node_list[node].get('upnode') + if not upnode or upnode == 'None': + break + upnodes_list.insert(0, upnode) + node = upnode + return upnodes_list + npu_upnodes_list = get_upnodes(node, prefix) + # 如果 matched_node 是 None 的话 + if matched_node is None: + previous_node = None # 用于跟踪上一个 node + for node in reversed(npu_upnodes_list): + if node_list.get(node, {}).get('matched_node_link'): # 判断条件 + matched_node = previous_node # 将 matched_node 设置为上一个 node + break + previous_node = node # 更新 previous_node 为当前 node + if prefix in constants.PREFIX_MAP: + matched_upnodes_list = get_upnodes(matched_node, prefix) + return http_util.Respond(request, [[prefix], npu_upnodes_list, matched_upnodes_list], "application/json") + + # 检查到底是读一般还是用之前存的 + def check_jsondata(self, request): + tag = request.args.get("tag") + if self._current_tag is None or self._current_tag != tag: + json_data, error_message = self.get_jsondata(request) + if error_message: + return http_util.Respond(request, error_message, "text/plain", 400) + else: + json_data = self._current_file_data + return json_data + + # 处理xx.get + def json_get(self, data, *args): + result = data + for key in args: + if result is None: + return None + result = result.get(key) + return result + + # 获取子图数据,最核心且基本的所在 + @wrappers.Request.application + def subgraph_route(self, request): + """Returns a subgraph for a given node id, modified to use run and tag from query parameters.""" + json_data = self.check_jsondata(request) + node_id = request.args.get("node") + self.batch_id = request.args.get("batch") + self.step_id = request.args.get("step") + if node_id is None: + return http_util.Respond( + request, 'The query parameter "node" is required', "text/plain", 400 + ) + if node_id == 'root': + if json_data.get('Bench', {}): + subgraph_pbtxt_set = {} + for node_type in ('Bench', 'NPU'): + subgraph = {'node': {}, 'edge': {}} + node = self.json_get(json_data, constants.SETS[node_type][0], 'root') + node_data = self.json_get(json_data ,constants.SETS[node_type][0], 'node', node) + node = constants.SETS[node_type][1] + node + matched_node_link = node_data['matched_node_link'] + if matched_node_link[0][:4] != constants.SETS[node_type][2]: + matched_node_link[0] = constants.SETS[node_type][2] + matched_node_link[0] + subgraph['node'][node] = node_data + subgraph_pbtxt_set[node_type] = self._convert_to_protobuf_format(subgraph) + subgraph_pbtxt = subgraph_pbtxt_set.get('NPU', '') + subgraph_pbtxt_set.get('Bench', '') + else: + subgraph = {'node': {}, 'edge': {}} + node = json_data.get('root') + node_data = self.json_get(json_data, 'node', node) + subgraph['node'][node] = node_data + subgraph_pbtxt = self._convert_to_protobuf_format(subgraph) + else: + subgraph = self._extract_subgraph(json_data, node_id) + subgraph_pbtxt = self._convert_to_protobuf_format(subgraph) + return http_util.Respond(request, subgraph_pbtxt, "text/x-protobuf") + + @wrappers.Request.application + def info_route(self, request): + info = self.info_impl() + return http_util.Respond(request, info, "application/json") + + #同上二者一体 + def _extract_subgraph(self, json_data, node_id): + """提取子图,支持多种节点前缀逻辑""" + subgraph = {'node': {}, 'edge': []} + + # 检查前缀并获取节点集合 + prefix = node_id[:4] + if prefix in constants.SETS and len(prefix) == 4: + node_id = node_id[4:] + node_set = self.json_get(json_data, constants.SETS[prefix][0], 'node') + else: + prefix = '' + node_set = json_data.get('node', {}) + + # 获取当前节点数据 + node_data = node_set.get(node_id, {}) + subnodes = node_data.get('subnodes', []) + + # 遍历子节点 + for subnode_id in subnodes: + subnode_id_data = node_set.get(subnode_id, {}) + if subnode_id_data.get('micro_step_id') is not None: + self._process_subnode(subgraph, prefix, subnode_id, subnode_id_data, json_data) + else: + self._process_non_root_subnode(subgraph, prefix, subnode_id, subnode_id_data) + + return subgraph + + def _process_non_root_subnode(self, subgraph, prefix, subnode_id, subnode_id_data): + """处理非根子节点""" + # 更新匹配的节点链接 + self._update_matched_node_links(subnode_id_data, prefix) + + # 添加前缀并存入子图 + full_subnode_id = prefix + subnode_id + subgraph['node'][full_subnode_id] = subnode_id_data + + #针对分micro_step_id和step_id取的部分节点 + def _process_subnode(self, subgraph, prefix, subnode_id, subnode_id_data, json_data): + batchid = subnode_id_data.get('micro_step_id') + stepid = subnode_id_data.get('step_id') + steplist = json_data.get('StepList') + + def should_update_node(): + """判断是否需要更新节点的条件逻辑""" + if self.batch_id == '-1': + if self.step_id == '-1': # batch_id 和 step_id 都为 -1 + return True + return stepid == str(steplist[int(self.step_id) + 1]) # 匹配 step_id + else: # batch_id 有效 + if self.step_id != '-1': # step_id 有效 + return batchid == int(self.batch_id) and stepid == str(steplist[int(self.step_id) + 1]) + return batchid == int(self.batch_id) # 仅匹配 batch_id + + if should_update_node(): + self._update_matched_node_links(subnode_id_data, prefix) + subnode_id = prefix + subnode_id + subgraph['node'][subnode_id] = subnode_id_data + + def _update_matched_node_links(self, subnode_id_data, prefix): + if 'matched_node_link' in subnode_id_data: + for index, matched_node_link in enumerate(subnode_id_data['matched_node_link']): + if matched_node_link[:4] != constants.SETS[prefix][1]: + matched_node_link = constants.SETS[prefix][1] + matched_node_link + subnode_id_data['matched_node_link'][index] = matched_node_link + + #拼接成类json + def _convert_to_protobuf_format(self, subgraph): + """Converts subgraph data to the protobuf text format expected by the frontend.""" + nodes = subgraph.get('node', {}) + protobuf_format = "" + for node_id, node_data in nodes.items(): + protobuf_format += f'node {{\n name: "{node_id}"\n op: "{node_data.get("id")}"\n' + protobuf_format += f' node_type: {node_data.get("node_type", 0)}\n' + if node_data.get("matched_node_link"): + protobuf_format += f' matched_node_link: {node_data.get("matched_node_link")}\n' + protobuf_format += f' attr: "{node_data.get("data", "{}")}"\n'.replace('True', 'true').replace('False', 'false') + protobuf_format += f' precision_index: {(node_data.get("data", "{}").get("precision_index"))}\n' + if node_data.get("input_data"): + protobuf_format += f' input_data: "{node_data.get("input_data", "{}")}"\n' + if node_data.get("output_data"): + protobuf_format += f' output_data: "{node_data.get("output_data", "{}")}"\n' + protobuf_format += f' suggestions: "{node_data.get("suggestions", "{}")}"\n' + if not node_data.get("subnodes"): + protobuf_format += f' isLeaf: true\n' + else: + protobuf_format += f' isLeaf: false\n' + protobuf_format += f' subnodes: {node_data.get("subnodes")}\n' + if node_data.get("stack_info"): + protobuf_format += f' stack_info: {node_data.get("stack_info")}\n' + protobuf_format += '}\n' + return protobuf_format + + def _get_run_dirs(self): + """Scan logdir for directories containing .vis files, modified to return a tuple of (run, tag).""" + run_tag_pairs = [] + for root, _, files in os.walk(self.logdir): + for file in files: + if file.endswith('.vis'): # check for .vis extension + run = os.path.abspath(root) + tag = os.path.splitext(file)[0] # Use the filename without extension as tag + file_path = os.path.join(root, file) + file_size = os.path.getsize(file_path) + if file_size > constants.MAX_FILE_SIZE: + logger.error(f'Error: the vis file "{file_path}" exceeds the maximum limit size of 1GB and will be skipped.') + continue + run_tag_pairs.append((run, tag)) + return run_tag_pairs + + def _load_json_file(self, run_dir, tag): + """Load a single .vis file from a given directory based on the tag.""" + file_path = os.path.join(run_dir, f"{tag}.vis") + if os.path.exists(file_path): + # Store the path of the current file instead of loading it into memory + self._current_file_path = file_path + return file_path + return None + + def _read_json_file(self, file_path): + """Read and parse a JSON file from disk.""" + if file_path and os.path.exists(file_path): + with open(file_path, 'r', encoding='utf-8') as f: + try: + return json.load(f) + except Exception as e: + logger.error(f'Error: the vis file "{file_path}" is not a legal JSON file!') + else: + logger.error(f'Error: the vis file "{file_path}" is not a legal JSON file!') + return None diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/static/index.js b/plugins/tensorboard-plugins/tb_graph_ascend/server/static/index.js new file mode 100644 index 0000000000000000000000000000000000000000..e64f87ec8631da37606fe740e1cdfe9155d34986 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/static/index.js @@ -0,0 +1,6 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + */ +export async function render() { + document.location.href = 'index.html'; +} diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/setup.py b/plugins/tensorboard-plugins/tb_graph_ascend/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..10ba16cbc1646f5f76ec5e5e4c6ffbe19a5ef069 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/setup.py @@ -0,0 +1,60 @@ +# ------------------------------------------------------------------------- +# Copyright (c) 2024, Huawei Technologies. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# --------------------------------------------------------------------------------------------# +import setuptools + +VERSION = '0.1.0' +INSTALL_REQUIRED = [ + "tensorboard >= 2.11.2" +] + + +setuptools.setup( + name="tb-graph-ascend", + version=VERSION, + description="Model Hierarchical Visualization TensorBoard Plugin", + long_description="Model Hierarchical Visualization TensorBoard Plugin : \ + https://gitee.com/ascend/att/tree/master/plugins/tensorboard-plugins/tb_graph_ascend", + url="https://gitee.com/ascend/att/tree/master/plugins/tensorboard-plugins/tb_graph_ascend", + author="Ascend Team", + author_email="pmail_mindstudio@huawei.com", + packages=setuptools.find_packages(), + package_data={ + "server": ["static/**"], + }, + entry_points={ + "tensorboard_plugins": [ + "graph_ascend = server.plugin:GraphsPlugin", + ], + }, + python_requires=">=3.7", + install_requires=INSTALL_REQUIRED, + classifiers=[ + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: BSD License', + 'Programming Language :: Python :: 3', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Mathematics', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + license='BSD-3', + keywords='tensorboard graph ascend plugin', +) diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/test/test_plugin.py b/plugins/tensorboard-plugins/tb_graph_ascend/test/test_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..595e4c46894c8a5d8089818ef7042e825c59e7b0 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/test/test_plugin.py @@ -0,0 +1,451 @@ +import sys +import os +sys.path.append(os.path.pardir) +from unittest.mock import MagicMock, patch +import unittest +from flask import Flask +from server.plugin import GraphsPlugin +class TestGraphsPlugin(unittest.TestCase): + + def setUp(self): + """在每个测试之前初始化环境""" + fake_context = MagicMock() + + # 创建 GraphsPlugin 实例并传递 context + self.plugin = GraphsPlugin(context=fake_context) + self.plugin._current_file_path = "" # 初始化文件路径 + self.plugin.batch_id = '-1' # 设置为 -1 来触发 _process_subnode 中的判断逻辑 + self.plugin.step_id = '-1' # 设置为 -1 来触发 _process_subnode 中的判断逻辑 + + self.plugin._current_file_data = { + "NPU": { + "node": { + "npu_node_1": { + "matched_node_link": [] + } + } + }, + "Bench": { + "node": { + "bench_node_1": { + "matched_node_link": [] + } + } + }, + "match": [], + 'task': 'md5' + } + + self.app = Flask(__name__) + self.app.debug = True + self.client = self.app.test_client() + + # 创建模拟的 data_provider + mock_data_provider = MagicMock() + mock_data_provider.some_method.return_value = "some_value" # 根据需要设置模拟的方法 + + # 创建一个模拟的 context,并将 mock_data_provider 赋值给它 + context = MagicMock() + context.data_provider = mock_data_provider + + # 使用 context 创建 GraphsPlugin 实例 + self.instance = GraphsPlugin(context=context) + + def test_get_all_nodeName_with_valid_batch_and_step(self): + # 模拟 request.args + mock_request = MagicMock() + mock_request.args.get.return_value = '0' # 模拟 batch=0 和 step=0 + + # 构造 json_data + json_data = { + 'NPU': { + 'root': 'root_node', + 'node': { + 'root_node': { + 'micro_step_id': 0, + 'subnodes': ['subnode1', 'subnode2'] + }, + 'subnode1': {'micro_step_id': 0}, + 'subnode2': {'micro_step_id': 0}, + } + }, + 'Bench': { + 'node': { + 'bench1': {}, + 'bench2': {} + } + } + } + + # 调用 get_all_nodeName 方法 + npu_ids, bench_ids = self.plugin.get_all_nodeName(json_data, mock_request) + + # 验证返回的 npu_ids 和 bench_ids + self.assertEqual(npu_ids, ['subnode1', 'subnode2']) + self.assertEqual(bench_ids, ['bench1', 'bench2']) + + def test_dfs_collect_nodes_with_valid_batch_and_step(self): + # 模拟 request.args + mock_request = MagicMock() + side_effect_dict = { + 'batch': '0', + 'step': '0' + } + + # 设置 mock_request.args.get 的 side_effect + mock_request.args.get.side_effect = side_effect_dict.get + + # 构造 json_data + json_data = { + 'NPU': { + 'node': { + 'node1': {'micro_step_id': 0, 'step_id': 0, 'subnodes': []}, + 'node2': {'micro_step_id': 0, 'step_id': 0, 'subnodes': []}, + 'node3': {'micro_step_id': 1, 'step_id': 1, 'subnodes': []}, + 'node4': {'micro_step_id': 0, 'step_id': 1, 'subnodes': ['subnode1']} + } + } + } + + # 调用 dfs_collect_nodes 方法 + all_node_names = self.plugin.dfs_collect_nodes(json_data, mock_request) + + # 验证返回的 all_node_names + self.assertEqual(all_node_names, ['node1', 'node2']) + + def test_group_precision_set_with_even_number_of_elements(self): + # 测试正常的输入(偶数个元素) + precision_set = [1, 2, 3, 4, 5, 6] + expected_result = [[1, 2], [3, 4], [5, 6]] + + # 调用 group_precision_set 方法 + result = self.plugin.group_precision_set(precision_set) + + # 验证结果是否正确 + self.assertEqual(result, expected_result) + + def test_group_precision_set_with_odd_number_of_elements(self): + # 测试输入长度为奇数的情况 + precision_set = [1, 2, 3] + + # 验证是否抛出 ValueError 异常 + with self.assertRaises(ValueError) as context: + self.plugin.group_precision_set(precision_set) + + self.assertEqual(str(context.exception), 'The number of elements in precision_set is not even') + + def test_process_data_with_md5_mismatch(self): + # 测试 md5 不匹配的情况 + data = [] + data_set = { + 'npu_keys': ['npu_key_1'], + 'bench_keys': ['bench_key_1'], + 'input_data': {'npu_key_1': [1, 2, 3], 'md5': 'abcd'}, + 'output_data': {'bench_key_1': [1, 2, 3], 'md5': 'efgh'}, + 'precision_index': 0, + 'file_path': 'test_path', + 'data_type': 'test_type', + } + NPU_node = 'test_node' + + # 调用方法 + result = self.plugin.process_data(data, data_set, NPU_node) + + # 验证结果 + self.assertEqual(result, 0) + self.assertEqual(data_set['precision_index'], 0) + + def test_should_update_node_with_valid_batch_and_step(self): + # 模拟 json_data 和 subnode_id_data + subnode_id_data = {'micro_step_id': '-1', 'step_id': '-1', 'matched_node_link': ['N___subnode_1']} + subgraph = {'node': {}} + json_data = { + 'StepList': ['0', '1', '2'] # 测试 StepList 数据 + } + + prefix = 'N___' + subnode_id = 'subnode_1' + + # 调用 _process_subnode 方法 + self.plugin._process_subnode(subgraph, prefix, subnode_id, subnode_id_data, json_data) + + # 验证 subnode_id 是否更新 + self.assertIn(prefix + subnode_id, subgraph['node']) + self.assertEqual(subgraph['node'][prefix + subnode_id], subnode_id_data) + + def mock_json_get(self, *args): + """ 模拟 json_get 方法,返回不同层级的数据 """ + if len(args) == 4 and args[1] == "node": + # 返回节点的 matched_node_link 数据 + return self.plugin._current_file_data[args[0]][args[1]].get(args[2], {}).get('matched_node_link', []) + return None + + def test_should_update_node_with_invalid_batch_or_step(self): + # 测试 batch_id 和 step_id 为无效值时不会更新 + self.plugin.batch_id = '-1' + self.plugin.step_id = '-1' + + subnode_id_data = {'micro_step_id': '1', 'step_id': '1', 'matched_node_link': []} + subgraph = {'node': {}} + json_data = { + 'StepList': ['0', '1', '2'] + } + + prefix = 'B___' + subnode_id = 'subnode_1' + + # 调用 _process_subnode 方法 + self.plugin._process_subnode(subgraph, prefix, subnode_id, subnode_id_data, json_data) + + # 验证 subnode_id 是否被更新 + self.assertIn(prefix + subnode_id, subgraph['node']) + self.assertEqual(subgraph['node'][prefix + subnode_id], subnode_id_data) + + def test_update_matched_node_links(self): + subnode_id_data = { + 'matched_node_link': ['link_1', 'link_2'] + } + prefix = 'B___' + + # 调用 _update_matched_node_links 方法 + self.plugin._update_matched_node_links(subnode_id_data, prefix) + + # 验证 matched_node_link 是否被正确更新 + self.assertEqual(subnode_id_data['matched_node_link'], ['N___link_1', 'N___link_2']) + + def test_no_update_matched_node_links(self): + subnode_id_data = { + 'matched_node_link': ['link_1', 'link_2'] + } + prefix = 'N___' + + # 模拟常量 SETS + constants = MagicMock() + constants.SETS = { + 'Bench': ('Bench', 'B___', 'N___'), + 'NPU': ('NPU', 'N___', 'B___'), + 'B___': ('Bench', 'N___'), + 'N___': ('NPU', 'B___') + } + + # 不更新第一个 matched_node_link + subnode_id_data['matched_node_link'][0] = 'prefixlink_1' + + # 调用 _update_matched_node_links 方法 + self.plugin._update_matched_node_links(subnode_id_data, prefix) + + # 验证 linked node 是否正确更新 + self.assertEqual(subnode_id_data['matched_node_link'], ['B___prefixlink_1', 'B___link_2']) + + @patch('os.walk') # 模拟 os.walk + def test_get_run_dirs(self, mock_os_walk): + """测试 _get_run_dirs 方法""" + + # 设置模拟返回的文件夹和文件 + fake_logdir = os.path.join(os.getcwd(), "fake", "logdir") # 使用绝对路径 + mock_os_walk.return_value = [(fake_logdir, [], ["run1_tag1.vis", "run2_tag2.vis"])] + + # 设置文件大小返回值 + with patch('os.path.getsize', return_value=500): # 模拟文件小于限制 + run_tag_pairs = self.plugin._get_run_dirs() + + # 验证返回的 run_tag_pairs + # 使用 os.path.normpath 来确保路径在不同操作系统上被标准化 + expected_run_tag_pairs = [ + (os.path.normpath(os.path.join(fake_logdir)), 'run1_tag1'), + (os.path.normpath(os.path.join(fake_logdir)), 'run2_tag2') + ] + + self.assertEqual(run_tag_pairs, expected_run_tag_pairs) + + @patch('os.path.getsize') # 模拟 os.path.getsize + def test_get_run_dirs_with_large_file(self, mock_getsize): + """测试 _get_run_dirs 方法,当文件超过大小限制时""" + + # 模拟一个文件大于最大限制 + mock_getsize.return_value = 2000 * 1024 * 1024 # 文件超过 1GB + + # 使用 os.path.join 来构建路径,确保兼容 Windows 和 Linux + fake_logdir = os.path.join("fake", "logdir") + large_file = "large_file.vis" + + with patch('os.walk', return_value=[(fake_logdir, [], [large_file])]): + run_tag_pairs = self.plugin._get_run_dirs() + + # 验证文件被跳过,不会返回任何文件 + self.assertEqual(run_tag_pairs, []) # 文件被跳过,不会返回任何文件 + + def test_convert_to_protobuf_format(self): + """测试 _convert_to_protobuf_format 方法""" + # 模拟节点数据 + subgraph = { + 'node': { + 'npu_node_1': { + 'id': 'op_1', + 'node_type': 1, + 'matched_node_link': ['bench_node_1'], + 'data': { + 'precision_index': 10, + 'other_data': 'value' + }, + 'input_data': {}, + 'output_data': {}, + 'suggestions': {}, + 'subnodes': [], + 'stack_info': 'stack_1' + } + } + } + + # 调用方法 + protobuf_format = self.plugin._convert_to_protobuf_format(subgraph) + + # 验证 protobuf 格式是否正确 + self.assertIn('node {', protobuf_format) + self.assertIn('name: "npu_node_1"', protobuf_format) + self.assertIn('op: "op_1"', protobuf_format) + self.assertIn('precision_index: 10', protobuf_format) + self.assertIn('isLeaf: true', protobuf_format) + + @patch('json.load') # 模拟 json.load + def test_read_json_file_invalid(self, mock_json_load): + """测试 _read_json_file 方法,当文件无效时""" + # 设置模拟的文件路径 + mock_file_path = os.path.join("fake", "file.vis") # 使用 os.path.join 来构造路径 + + # 模拟 json.load 抛出异常 + mock_json_load.side_effect = Exception("Invalid JSON") + + # 使用模拟的路径读取文件 + result = self.plugin._read_json_file(mock_file_path) + + # 验证返回值是 None,并且日志中有错误消息 + self.assertIsNone(result) + + @patch('os.path.exists', return_value=False) + def test_load_json_file_not_found(self, mock_exists): + """测试 _load_json_file 方法,当文件不存在时""" + + # 使用 os.path.join 来确保路径的兼容性 + mock_file_path = os.path.join("fake", "file.vis") + mock_tag = "tag1" + + # 调用方法 + result = self.plugin._load_json_file(mock_file_path, mock_tag) + + # 验证返回值是否为 None + self.assertIsNone(result) + + # 验证 _current_file_path 是否为空 + self.assertEqual(self.plugin._current_file_path, "") + + def test_get_input_output_data(self): + NPU_node_data = { + 'input_data': {'input_1': 'data1', 'input_2': 'data2'}, + 'output_data': {'output_1': 'data3'} + } + Bench_node_data = { + 'input_data': {'input_1': 'dataA', 'input_2': 'dataB'}, + 'output_data': {'output_1': 'dataC', 'output_2': 'dataD'} + } + + # 调用方法 + NPU_input_data, Bench_input_data, NPU_output_data, Bench_output_data = self.instance.get_input_output_data(NPU_node_data, Bench_node_data) + + # 验证返回结果 + self.assertEqual(NPU_input_data, {'input_1': 'data1', 'input_2': 'data2'}) + self.assertEqual(Bench_input_data, {'input_1': 'dataA', 'input_2': 'dataB'}) + self.assertEqual(NPU_output_data, {'output_1': 'data3'}) + self.assertEqual(Bench_output_data, {'output_1': 'dataC', 'output_2': 'dataD'}) + + def test_calculate_min_length(self): + # 模拟输入输出数据 + NPU_input_data = {'input_1': 'data1', 'input_2': 'data2'} + Bench_input_data = {'input_1': 'dataA', 'input_2': 'dataB', 'input_3': 'dataC'} + NPU_output_data = {'output_1': 'data3'} + Bench_output_data = {'output_1': 'dataC', 'output_2': 'dataD'} + + # 调用方法 + input_min_length, output_min_length = self.instance.calculate_min_length(NPU_input_data, Bench_input_data, NPU_output_data, Bench_output_data) + + # 验证返回值 + self.assertEqual(input_min_length, 2) # 最小输入数据长度 + self.assertEqual(output_min_length, 1) # 最小输出数据长度 + + def test_get_keys(self): + # 模拟 NPU 和 Bench 数据 + NPU_data = { + 'input_data': {'input_1': 'data1', 'input_2': 'data2'}, + 'output_data': {'output_1': 'data3'} + } + Bench_data = { + 'input_data': {'input_1': 'dataA', 'input_2': 'dataB'}, + 'output_data': {'output_1': 'dataC', 'output_2': 'dataD'} + } + + # 调用方法 + npu_keys, bench_keys = self.instance.get_keys(NPU_data, Bench_data) + + # 验证返回值 + self.assertEqual(npu_keys, ['input_data', 'output_data']) + self.assertEqual(bench_keys, ['input_data', 'output_data']) + + def test_calculate_diff_and_relative_error(self): + # 模拟 npu_data 和 bench_data + npu_data = { + 'Max': 10.0, + 'Min': 1.0, + 'Mean': 5.5, + 'Norm': 8.0 + } + bench_data = { + 'Max': 12.0, + 'Min': 2.0, + 'Mean': 5.0, + 'Norm': 7.0 + } + + # 调用方法 + results = self.instance.calculate_diff_and_relative_error(npu_data, bench_data) + + # 验证返回结果 + self.assertEqual(results['NPU_Max'], 10.0) + self.assertEqual(results['Bench_Max'], 12.0) + self.assertEqual(results['Max_relative_err'], "16.666667%") + + self.assertEqual(results['NPU_Min'], 1.0) + self.assertEqual(results['Bench_Min'], 2.0) + self.assertEqual(results['Min_relative_err'], "50.000000%") + + self.assertEqual(results['NPU_Mean'], 5.5) + self.assertEqual(results['Bench_Mean'], 5.0) + self.assertEqual(results['Mean_relative_err'], "10.000000%") + + self.assertEqual(results['NPU_Norm'], 8.0) + self.assertEqual(results['Bench_Norm'], 7.0) + self.assertEqual(results['Norm_relative_err'], "14.285714%") + + def test_calculate_diff_and_relative_error_with_zero_bench_value(self): + # 模拟 npu_data 和 bench_data,其中 bench_data 有零值 + npu_data = { + 'Max': 10.0, + 'Min': 1.0, + 'Mean': 5.5, + 'Norm': 8.0 + } + bench_data = { + 'Max': 12.0, + 'Min': 0.0, # Bench Min 为 0,应该触发 "N/A" 处理 + 'Mean': 5.0, + 'Norm': 0.0 # Bench Norm 为 0,应该触发 "N/A" 处理 + } + + # 调用方法 + results = self.instance.calculate_diff_and_relative_error(npu_data, bench_data) + + # 验证返回结果 + self.assertEqual(results['Min_relative_err'], "N/A") + self.assertEqual(results['Norm_relative_err'], "N/A") + +if __name__ == '__main__': + unittest.main() diff --git a/profiler/README.md b/profiler/README.md index 1669e3524e54bb78e6f4f09f597d2399196ff950..4db89899f1a8d1e22c7adb5556deb4f09067b725 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -26,7 +26,7 @@ with torch_npu.profiler.profile( profile_memory=True, with_stack=True, experimental_config=experimental_config, - schedule=torch.profiler.schedule(wait=10, warmup=0, active=1, repeat=1), + schedule=torch_npu.profiler.schedule(wait=10, warmup=0, active=1, repeat=1), on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./profiling_data") ) as prof: # 模型训练代码 @@ -81,7 +81,23 @@ ascend pytorch profiler数据目录结构如下: ## 工具安装 -性能工具的安装方式包括:**下载whl包安装**和**源代码编译安装**。 +性能工具的安装方式包括:**pip安装**、**下载whl包安装**和**源代码编译安装**。 + +### pip安装 + +```shell +pip install msprof-analyze +``` + +使用`pip install msprof-analyze==版本号`可安装指定版本的包,支持1.2.1及之后版本,版本号参见“**下载whl包安装**”。 + +pip命令会自动安装最新的包及其配套依赖。 + +提示如下信息则表示安装成功。 + +```bash +Successfully installed msprof-analyze-{version} +``` #### 下载whl包安装 @@ -89,12 +105,19 @@ ascend pytorch profiler数据目录结构如下: 请通过下表链接下载profiler工具whl包。 - | profiler版本 | 发布日期 | 下载链接 | 校验码 | - | ------------ | ---------- | ------------------------------------------------------------ | ------------------------------------------------------------ | - | 1.1.2 | 2024-07-12 | [msprof_analyze-1.1.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.2/msprof_analyze-1.1.2-py3-none-any.whl) | af62125b1f9348bf491364e03af712fc6d0282ccee3fb07458bc9bbef82dacc6 | - | 1.1.1 | 2024-06-20 | [msprof_analyze-1.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.1/msprof_analyze-1.1.1-py3-none-any.whl) | 76aad967a3823151421153d368d4d2f8e5cfbcb356033575e0b8ec5acea8e5e4 | - | 1.1.0 | 2024-05-28 | [msprof_analyze-1.1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.0/msprof_analyze-1.1.0-py3-none-any.whl) | b339f70e7d1e45e81f289332ca64990a744d0e7ce6fdd84a8d82e814fa400698 | - | 1.0 | 2024-05-10 | [msprof_analyze-1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.0/msprof_analyze-1.0-py3-none-any.whl) | 95b2f41c8c8e8afe4887b738c8cababcb4f412e1874483b6adae4a025fcbb7d4 | + | profiler版本 | 发布日期 | 下载链接 | 校验码 | + |------------|------------|-------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------| + | 1.3.0 | 2024-10-12 | [msprof_analyze-1.3.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.3.0/msprof_analyze-1.3.0-py3-none-any.whl) | 8b09758c6b5181bb656a95857c32852f898c370e7f1041e5a08e4f10d5004d48 | + | 1.2.5 | 2024-09-25 | [msprof_analyze-1.2.5-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.5/msprof_analyze-1.2.5-py3-none-any.whl) | aea8ae8deac07b5b4980bd2240da27d0eec93b9ace9ea9eb2e3a05ae9072018b | + | 1.2.4 | 2024-09-19 | [msprof_analyze-1.2.4-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.4/msprof_analyze-1.2.4-py3-none-any.whl) | 7c392e72c3347c4034fd3fdfcccb1f7936c24d9c3eb217e2cc05bae1347e5ab7 | + | 1.2.3 | 2024-08-29 | [msprof_analyze-1.2.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.3/msprof_analyze-1.2.3-py3-none-any.whl) | 354a55747f64ba1ec6ee6fe0f05a53e84e1b403ee0341ec40cc216dd25fda14c | + | 1.2.2 | 2024-08-23 | [msprof_analyze-1.2.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.2/msprof_analyze-1.2.2-py3-none-any.whl) | ed92a8e4eaf5ada8a2b4079072ec0cc42501b1b1f2eb00c8fdcb077fecb4ae02 | + | 1.2.1 | 2024-08-14 | [msprof_analyze-1.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.1/msprof_analyze-1.2.1-py3-none-any.whl) | 7acd477417bfb3ea29029dadf175d019ad3212403b7e11dc1f87e84c2412c078 | + | 1.2.0 | 2024-07-25 | [msprof_analyze-1.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.2.0/msprof_analyze-1.2.0-py3-none-any.whl) | 6a4366e3beca40b4a8305080e6e441d6ecafb5c05489e5905ac0265787555f37 | + | 1.1.2 | 2024-07-12 | [msprof_analyze-1.1.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.2/msprof_analyze-1.1.2-py3-none-any.whl) | af62125b1f9348bf491364e03af712fc6d0282ccee3fb07458bc9bbef82dacc6 | + | 1.1.1 | 2024-06-20 | [msprof_analyze-1.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.1/msprof_analyze-1.1.1-py3-none-any.whl) | 76aad967a3823151421153d368d4d2f8e5cfbcb356033575e0b8ec5acea8e5e4 | + | 1.1.0 | 2024-05-28 | [msprof_analyze-1.1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.1.0/msprof_analyze-1.1.0-py3-none-any.whl) | b339f70e7d1e45e81f289332ca64990a744d0e7ce6fdd84a8d82e814fa400698 | + | 1.0 | 2024-05-10 | [msprof_analyze-1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/profiler/package/1.0/msprof_analyze-1.0-py3-none-any.whl) | 95b2f41c8c8e8afe4887b738c8cababcb4f412e1874483b6adae4a025fcbb7d4 | diff --git a/profiler/advisor/README.md b/profiler/advisor/README.md index c650f40b3ea8ef48b3c7644e279b00a1cb99f29a..c0148afc13dfd05851f8bd5f57d04eed2f89a7c5 100644 --- a/profiler/advisor/README.md +++ b/profiler/advisor/README.md @@ -1,11 +1,16 @@ # advisor -msprof-analyze的advisor功能是将Ascend PyTorch Profiler或者msprof采集的PyThon场景性能数据进行分析,并输出性能调优建议(当前暂不支持对db格式文件分析)。 +msprof-analyze的advisor功能是将Ascend PyTorch Profiler或者msprof采集的性能数据进行分析,并输出性能调优建议。 性能数据采集方法请参见《[性能分析工具](https://www.hiascend.com/document/detail/zh/mindstudio/70RC1/mscommandtoolug/mscommandug/atlasprofiling_16_0001.html)》。 ## 工具使用(命令行方式方式) +### 约束 + +- 不支持对db格式文件分析。 +- 不支持分析MindSpore场景采集的性能数据。 + ### 操作步骤 1. 参见《[性能工具](../README.md)》完成工具安装。建议安装最新版本。 @@ -36,11 +41,11 @@ msprof-analyze的advisor功能是将Ascend PyTorch Profiler或者msprof采集的 3. 查看结果。 - 分析结果输出相关简略建议到执行终端中,并生成`att_advisor_{timestamp}.html`和`att_advisor_{timestamp}.xlsx`文件供用户预览。 + 分析结果输出相关简略建议到执行终端中,并生成`mstt_advisor_{timestamp}.html`和`mstt_advisor_{timestamp}.xlsx`文件供用户预览。 - `att_advisor_{timestamp}.xlsx`文件内容与执行终端输出一致。 + `mstt_advisor_{timestamp}.xlsx`文件内容与执行终端输出一致。 - `att_advisor_{timestamp}.html`文件分析详见“**报告解析**”。 + `mstt_advisor_{timestamp}.html`文件分析详见“**报告解析**”。 执行终端输出示例如下: @@ -62,49 +67,67 @@ msprof-analyze的advisor功能是将Ascend PyTorch Profiler或者msprof采集的 #### 命令功能介绍 -| dimension | mode | 参数释义 | -| ---------- | -------------------------- | ---------------------------------------- | -| overall | overall_summary | 计算、通信、空闲等维度对性能数据进行拆解 | -| cluster | slow_rank | 慢卡识别 | -| | slow_link | 慢链路识别 | -| computing | aicpu | AI CPU调优 | -| | dynamic_shape_analysis | 识别动态Shape算子 | -| | block_dim_analysis | block dim算子调优 | -| | operator_no_bound_analysis | operator no bound | -| | graph | 融合算子图调优 | -| scheduling | timeline_fusion_ops | 亲和API替换调优 | -| | timeline_op_dispatch | 识别算子下发问题(路径3/路径5) | +msprof-analyze advisor命令行包含如下三个参数: - all - 总体性能瓶颈:包含上表中所有功能。 + 总体性能瓶颈:包含下表中所有功能。 - computation - 计算瓶颈:包含上表中computing功能。 + 计算瓶颈:包含下表中computing和Kernel compare功能。 - schedule - 调度瓶颈:包含上表中scheduling功能。 + 调度瓶颈:包含下表中scheduling和API compare功能。 + +下表中字段为advisor的完整功能点,由all、computation和schedule控制启动。 + +| dimension | mode | 参数释义 | +| ---------- |---------------------------------------| ------------------------------------ | +| overall | overall summary | 计算、通信、空闲等维度对性能数据进行拆解 | +| | environment_variable_analysis | 环境变量设置推荐 | +| cluster | slow rank | 慢卡识别 | +| | slow link | 慢链路识别 | +| computing | AICPU operator | AI CPU调优 | +| | Dynamic shape operator | 识别动态Shape算子 | +| | block dim | block dim算子调优 | +| | operator no bound | 算子瓶颈分析 | +| | fusion issue | 融合算子图调优 | +| | AI Core Frequency | AI Core算子降频分析 | +|communication| Packet analysis |通信小包检测 | +|| bandwidth contention analysis |通信计算带宽抢占检测 | +|| Communication retransmission analysis |通信重传检测 | +| scheduling | Affinity apis | 亲和API替换调优 | +| | Operator dispatch | 识别算子下发问题(路径3/路径5) | +| | SyncBatchNorm | BatchNorm同步检测 | +| | SynchronizeStream | 流同步检测 | +| | Slow dataloader | 异常dataloader检测 | +| | gc | 识别异常垃圾回收事件。需要Ascend PyTorch Profiler采集时开启experimental_config下的gc_delect_threshold功能 | +| memory | Memory | 识别异常的内存申请释放操作 | +| comparison | Kernel compare of Rank\* Step\* and Rank\* Step\* | 识别标杆和待比对性能数据的Kernel数据(无标杆场景是集群内部快慢卡的性能数据对比,有标杆场景是两个集群之间存在明显耗时差异的相同卡之间的性能数据对比) | +| | API compare of Rank\* Step\* and Rank\* Step\* | 识别标杆和待比对性能数据的API数据(无标杆场景是集群内部快慢卡的性能数据对比,有标杆场景是两个集群之间存在明显耗时差异的相同卡之间的性能数据对比) | + +集群场景时自动进行cluster和overall的environment_variable_analysis解析,单卡时自动进行overall解析。 #### 命令格式 - 总体性能瓶颈 ```bash - msprof-analyze advisor all -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [--debug] [-h] + msprof-analyze advisor all -d {profiling_path} [-bp benchmark_profiling_path] [-o output_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [--debug] [-h] ``` - 计算瓶颈 ```bash - msprof-analyze advisor computation -d {profiling_path} [-cv cann_version] [-tv torch_version] [-pt profiling_type] [--debug] [-h] + msprof-analyze advisor computation -d {profiling_path} [-o output_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [--debug] [-h] ``` - 调度瓶颈 ```bash - msprof-analyze advisor schedule -d {profiling_path} [-cv cann_version] [-tv torch_version] [--debug] [-h] + msprof-analyze advisor schedule -d {profiling_path} [-o output_path] [-cv cann_version] [-tv torch_version] [--debug] [-h] ``` #### 参数介绍 @@ -113,28 +136,136 @@ msprof-analyze的advisor功能是将Ascend PyTorch Profiler或者msprof采集的 | ---------------------------------- | ------------------------------------------------------------ | -------- | | -d
--profiling_path | 性能数据文件或目录所在路径,Ascend PyTorch Profiler采集场景指定为`*_ascend_pt`性能数据结果目录,其他场景指定为`PROF_XXX`性能数据结果目录。建议通过Ascend PyTorch Profiler获取性能数据。
advisor依赖Profiling工具解析后的timeline数据(.json)、summary(.csv)数据以及info.json*文件,请确保指定的“profiling_path”目录下存在以上文件。 | 是 | | -bp
--benchmark_profiling_path | 基准性能数据所在目录,用于性能比对。性能数据通过Profiling工具采集获取。
**computation和schedule不支持该参数。** | 否 | -| -cv
--cann_version | 使用Profiling工具采集时对应的CANN软件版本,可通过在环境中执行如下命令获取其version字段,目前配套的兼容版本为“6.3.RC2”,“7.0.RC1”、“7.0.0”、“8.0.RC1”,此字段不填默认按“8.0.RC1”版本数据进行处理,其余版本采集的Profiling数据在分析时可能会导致不可知问题:`cat /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/ascend_toolkit_install.info` | 否 | +| -o
--output_path | 分析结果输出路径,完成advisor分析操作后会在该目录下保存分析结果数据。默认未配置,为当前目录。 | 否 | +| -cv
--cann_version | 使用Profiling工具采集时对应的CANN软件版本。目前配套的兼容版本为“6.3.RC2”,“7.0.RC1”、“7.0.0”、“8.0.RC1”,此字段不填默认按“8.0.RC1”版本数据进行处理,其余版本采集的Profiling数据在分析时可能会导致不可知问题。可通过在环境中执行如下命令获取其version字段:`cat /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/ascend_toolkit_install.info` | 否 | | -tv
--torch_version | 运行环境的torch版本,默认为1.11.0,支持torch1.11.0和torch2.1.0,当运行环境torch版本为其他版本如torch1.11.3时,可以忽略小版本号差异选择相近的torch版本如1.11.0。 | 否 | | -pt
--profiling_type | 配置性能数据采集使用的Profiling工具类型。可取值:
ascend_pytorch_profiler:使用Ascend PyThon Profiler接口方式采集的性能数据时配置,默认值。
msprof:使用msprof命令行方式采集的性能数据时配置。功能完善中,暂不建议使用。
mslite:使用[Benchmark](https://gitee.com/ascend/tools/tree/master/ais-bench_workload/tool/ais_bench)工具采集的性能数据时配置。不建议使用。
**schedule不支持该参数。** | 否 | | --debug | 工具执行报错时可打开此开关,将会展示详细保存堆栈信息。 | 否 | | -h,-H
--help | 在需要查询当前命令附属子命令或相关参数时,给出帮助建议。 | 否 | -### 报告解析 +### 报告解析(无标杆) + +无标杆是指执行msprof-analyze advisor时,未配置-bp参数,会根据性能数据中的computing time和free time判断是否进行kernel和API性能数据的对比,以慢卡数据为标杆数据,快卡数据为待比对数据。 -如下图所示,工具会从集群、单卡性能拆解、调度和计算等维度进行问题诊断并给出相应的调优建议。 +如下图所示,工具会从集群、单卡性能拆解、调度和计算等维度进行问题诊断并给出相应的调优建议。并通过红、黄、绿色块表示问题优先级,分别为High(高)、Medium(中)、Low(低)。 ![输入图片说明](./img/cluster.png) -cluster模块的分析包含快慢卡和快慢链路分析,仅识别问题,不提供调优建议。 -如下图示例,识别到当前训练任务的通信和下发(free较多说明存在任务下发存在问题)存在问题。 +#### overall模块的分析 + +overall模块仅识别问题,不提供调优建议。 + +- 无标杆单卡场景的overall模块的Environment Variable Issues是对环境变量的设置做出推荐。 + + ![env_var.png](./img/env_var.png) + + 上图中的环境变量详细介绍请参见[ACLNN_CACHE_LIMIT](https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/apiref/envvar/envref_07_0031.html)和[HOST_CACHE_CAPACITY](https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/developmentguide/appdevg/aclpythondevg/aclpythondevg_0045.html)。 + +- 无标杆单卡场景的overall模块的overall summary分析包含当前训练任务慢卡的性能拆解,按照计算、通信和下发三个维度进行耗时的统计,可以基于该分析识别到训练性能瓶颈是计算、通信还是下发问题,同样不提供调优建议。 + + ![输入图片说明](./img/overall_0.png) + + ![输入图片说明](./img/overall.png) + +- 无标杆集群场景的overall模块包含快慢卡和快慢链路分析。 + + ![cluster_1](./img/cluster_1.png) + + ![cluster_3](./img/cluster_3.png) + + ![cluster_4](./img/cluster_4.png) + + ![cluster_5](./img/cluster_5.png) + +#### comparison + +comparison模块内容如下图示例,识别标杆和待比对性能数据的Kernel和API数据,无标杆场景的comparison是集群内部快慢卡的性能数据对比。包括: + +- Kernel compare of Rank* Step* and Rank* Step*:Kernel的待比对总耗时、待比对平均耗时、待比对最大耗时、待比对最小耗时和待比对执行次数,以及标杆的对应数据,最后计算Diff Total Ratio(标杆总耗时/待比对总耗时)和Diff Avg Ratio(标杆平均耗时/待比对平均耗时)。 + + Diff Total Ratio和Diff Avg Ratio大于1则表示当前环境性能更优,小于1则表示当前环境有待优化,等于1则表示当前环境与标杆环境性能接近。 + + ![comparison2](./img/comparison2.png) + + 其中inf表示分母为0(未获取到待对比数据或待对比数据为0),None表示未获取到数据。 + +- Api compare of Rank* Step* and Rank* Step*:API的待比对总耗时、待比对API自身耗时(除去API调用的子API的耗时)、待比对平均耗时和待比对执行次数,以及标杆的对应数据,最后计算Diff Total Ratio(标杆总耗时/待比对总耗时)、Diff Self Ratio(标杆API自身耗时/待比对API自身耗时)、Diff Avg Ratio(标杆平均耗时/待比对平均耗时)和Diff Calls Ratio(标杆执行次数/待比对执行次数)。 + + Diff Total Ratio、Diff Self Ratio、Diff Avg Ratio和Diff Calls Ratio大于1则表示当前环境性能更优,小于1则表示当前环境有待优化,等于1则表示当前环境与标杆环境性能接近。 + + ![comparison3](./img/comparison3.png) + + 其中inf表示分母为0(未获取到待对比数据或待对比数据为0),None表示未获取到数据。 + +`mstt_advisor_{timestamp}.html`文件的comparison模块内容仅展示Kernel和API的Top 10条数据,详细数据需要查看`mstt_advisor_{timestamp}.xlsx`文件。 + +#### performance problem analysis模块的分析 + +performance problem analysis模块包含如下子模块。 + +memory模块分析内存的异常申请释放操作。 + +![memory](./img/memory.png) + +communication模块从通信维度进行分析,目前支持通信小算子检测。 + +![communication](./img/communication.png) + +上图中Zero1/Zero2/Zero3含义如下: + +- Zero1:每张NPU存储完整的一份梯度和模型参数,只有1/N优化器。每张NPU使用各自的数据做前向传播、反向传播,反向传播后使用all-reduce同步梯度到所有卡,使得每张卡有所有算子的梯度。每张卡根据梯度和1/N优化器更新1/N模型参数,再使用all-gather通信将优化器更新后的1/N模型参数发送给其它卡,因为每张卡有完整的一份模型参数需要更新。 +- Zero2:每张NPU存储完整的一份模型参数,只有1/N优化器和1/N梯度。每张NPU使用各自的数据做前向传播。反向传播后,计算出本卡的局部梯度,使用Reduce-Scatter通信聚合梯度,保证每张卡只保存1/N梯度。每张卡根据自己保持的1/N优化器和1/N梯度更新1/N模型参数,再使用all-gather通信将更新后的模型参数发送给其它卡,因为每张卡有完整的一份模型参数需要更新。 +- Zero3:每张NPU存储1/N模型参数、1/N优化器和1/N梯度。前向传播前,每张卡all-gather通信获取到完整的模型参数,再做前向传播计算,每用完一部分模型参数后就把它删除。反向传播开始前,每张卡all-gather通信获取到完整的模型参数,每用完一部分模型参数后就把它删除。使用reduce-scatter通信聚合梯度。每张卡根据自己保持的1/N优化器和1/N梯度更新1/N模型参数,由于每张卡只保存1/N模型参数,无需要将更新后的模型参数发送给其它卡。 + +通信重传检测分析,识别发生重传的通信域并提供调优建议。 + +如下图所示,识别到当前训练任务存在通信重传问题,并提供调优建议。 + +![cluster_2](./img/cluster_2.png) -![cluster_1](./img/cluster_1.png) +带宽抢占分析,检测计算和通信并发时,通信带宽被抢占的场景。 -overall模块的分析包含当前训练任务慢卡的性能拆解,按照计算、通信和下发三个维度进行耗时的统计,可以基于该分析识别到训练性能瓶颈是计算、通信还是下发问题,同样不提供调优建议。 +![bandwidth](./img/bandwidth.png) -![输入图片说明](./img/overall.png) +computation模块从device计算性能维度进行分析,能够识别AI CPU、计算bound、动态Shape、AI Core算子降频分析等问题并给出相应建议。此处不再详细展开,按照报告进行调优即可。示例如下: + +![computation_1](./img/computation_1.png) + +上图中torch_npu.npu.set_compile_mode接口介绍请参见[torch_npu.npu.set_compile_mode](https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/apiref/apilist/ptaoplist_000880.html);AICPU算子替换样例可参考《[Samples of AI CPU Operator Replacement](https://gitee.com/ascend/mstt/blob/master/profiler/advisor/doc/Samples%20of%20AI%20CPU%20Operator%20Replacement.md)》。 + +当存在pp stage(流水线并行)时,computation会按stage分析,每个stage就是一个流水线切分,比如0\~7卡为stage-0、8\~15卡为stage-1。 + +![computation_2](./img/computation_2.png) + +dataloader模块包含Slow Dataloader Issues,主要检测异常高耗时的dataloader调用,并给出优化建议。 + +![dataloader](./img/dataloader.png) + +上图中的`pin_memory`(内存锁定)和`num_workers`(数据加载是子流程数量)参数为[数据加载优化](https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/ptmoddevg/trainingmigrguide/performance_tuning_0019.html)使用。 + +schedule模块包GC Analysis、含亲和API、aclOpCompile、syncBatchNorm、SynchronizeStream等多项检测。 + +如下图示例,GC Analysis提示存在异常垃圾回收事件,用户可以通过有效的Python内存管理、使用`gc.set_threshold()`调整垃圾回收阈值、使用gc.disable()禁用gc等方法处理GC问题。 + +![gc](./img/gc.png) + +上图中`gc.set_threshold()`和`gc.disable()`函数说明如下: + +在Python中,gc模块提供了对垃圾回收器的控制。 + +- `gc.set_threshold(threshold0, thresholdl, threshold2)`:这个函数用于设置垃圾回收的阈值。垃圾回收器将所有对象分为三代(0代、1代和2代),每一代的对象在经历垃圾回收后会被移到下一代。`threshold0`控制第0代的垃圾回收频率,`threshold1`控制第1代的垃圾回收频率,`threshold2`控制第2代的垃圾回收频率。将`threshold0`设为0可以禁用垃圾回收。 +- `gc.disable ()`:这个函数用于禁用自动垃圾回收。调用`gc.disable ()`后,垃圾回收器将不会自动运行,直到手动调用`gc.enable()`。 + +如下图示例,Affinity API Issues提示存在可以替换的亲和API并给出对应的堆栈,用户可以根据堆栈找到需要修改的代码,并给出修改案例([API instruction](https://gitee.com/ascend/mstt/blob/master/profiler/advisor/doc/Samples%20of%20Fused%20Operator%20API%20Replacement.md))。 + +![schedule_3](./img/schedule_3.png) + +如下图示例,Synchronize Stream Issues提示存在耗时较多的同步流,并给出触发同步流的堆栈,需要根据堆栈来修改对应代码消除同步流。 + +![schedule_2](./img/schedule_2.png) + +上图中的ASCEND_LAUNCH_BLOCKING环境变量介绍请参见[ASCEND_LAUNCH_BLOCKING](https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/apiref/envvar/envref_07_0050.html)。 -schedule模块包含亲和API、aclOpCompile、syncBatchNorm、SynchronizeStream等多项检测。 如下图示例,Operator Dispatch Issues提示需要在运行脚本的最开头添加如下代码用于消除aclOpCompile: ```python @@ -142,19 +273,43 @@ torch_npu.npu.set_compile_mode(jit_compile=False); torch_npu.npu.config.allow_internal_format = False ``` +以上接口介绍请参见[torch_npu.npu.set_compile_mode](https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/apiref/apilist/ptaoplist_000880.html)和[torch_npu.npu.config.allow_internal_format](https://www.hiascend.com/document/detail/zh/Pytorch/60RC2/apiref/apilist/ptaoplist_000216.html)。 + ![输入图片说明](./img/schedule_1.png) -如下图示例,Synchronize Stream Issues提示存在耗时较多的同步流,并给出触发同步流的堆栈,需要根据堆栈来修改对应代码消除同步流。 +上图中aclopCompileAndExecute接口介绍请参见[aclopCompileAndExecute](https://www.hiascend.com/document/detail/zh/canncommercial/80RC22/apiref/appdevgapi/aclcppdevg_03_0243.html)。 -![schedule_2](./img/schedule_2.png) +### 报告解析(有标杆) -如下图示例,Affinity API Issues提示存在可以替换的亲和API并给出对应的堆栈,用户可以根据堆栈找到需要修改的代码,并给出修改案例(API instruction超链接)。 +有标杆是指执行msprof-analyze advisor时,配置-bp参数,指定基准性能数据进行比对。 -![schedule_3](./img/schedule_3.png) +有标杆单卡场景:不进行overall模块的分析,performance problem analysis模块与有标杆场景下的performance problem analysis模块结果一致。 -computation模块从device计算性能维度进行分析,能够识别AI CPU、计算bound、动态Shape等问题并给出相应建议。此处不再详细展开,按照报告进行调优即可。 +有标杆集群场景: -![computation_1](./img/computation_1.png) +- overall模块进行快慢卡和快慢链路分析,与无标杆集群场景一致,请参见“**报告解析(无标杆)** > **overall模块的分析**”。 +- 提供Environment Variable Issues,与无标杆单卡场景一致,请参见“**报告解析(无标杆)** > **overall模块的分析**”。 +- 有标杆集群场景同样提供comparison模块(无标杆场景是集群内部快慢卡的性能数据对比,有标杆场景是两个集群之间存在明显耗时差异的相同卡之间的性能数据对比)。 + +comparison模块内容如下图示例,识别标杆和待比对性能数据的Kernel和API数据,包括: + +- Kernel compare of Target and Benchmark:Kernel的待比对总耗时、待比对平均耗时、待比对最大耗时、待比对最小耗时和待比对执行次数,以及标杆的对应数据,最后计算Diff Total Ratio(标杆总耗时/待比对总耗时)和Diff Avg Ratio(标杆平均耗时/待比对平均耗时)。 + + Diff Total Ratio和Diff Avg Ratio大于1则表示当前环境性能更优,小于1则表示当前环境有待优化,等于1则表示当前环境与标杆环境性能接近。 + + ![comparison](./img/comparison.png) + + 其中inf表示分母为0(未获取到待对比数据或待对比数据为0),None表示未获取到数据。 + +- Api compare of Target and Benchmark:API的待比对总耗时、待比对API自身耗时(除去API调用的子API的耗时)、待比对平均耗时和待比对执行次数,以及标杆的对应数据,最后计算Diff Total Ratio(标杆总耗时/待比对总耗时)、Diff Self Ratio(标杆API自身耗时/待比对API自身耗时)、Diff Avg Ratio(标杆平均耗时/待比对平均耗时)和Diff Calls Ratio(标杆执行次数/待比对执行次数)。 + + Diff Total Ratio、Diff Self Ratio、Diff Avg Ratio和Diff Calls Ratio大于1则表示当前环境性能更优,小于1则表示当前环境有待优化,等于1则表示当前环境与标杆环境性能接近。 + + ![comparison1](./img/comparison1.png) + + 其中inf表示分母为0(未获取到待对比数据或待对比数据为0),None表示未获取到数据。 + +`mstt_advisor_{timestamp}.html`文件的comparison模块内容仅展示Kernel和API的Top 10条数据,详细数据需要查看`mstt_advisor_{timestamp}.xlsx`文件。 ## 工具使用(Jupyter Notebook方式) diff --git a/profiler/advisor/__init__.py b/profiler/advisor/__init__.py index e79018ed05c6d1cdeb56feaa6182f048e3c8e06f..8400fd5ecd1246eaee795cebfccfacc80a94f08c 100644 --- a/profiler/advisor/__init__.py +++ b/profiler/advisor/__init__.py @@ -12,6 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -from profiler.advisor.interface.interface import Interface \ No newline at end of file diff --git a/profiler/advisor/advisor_backend/advice_base.py b/profiler/advisor/advisor_backend/advice_base.py index 35939bcea9c87fb09f2113bd19f77ea18ba54e34..97fd9ff4c04bbbbd1f630af92d112334dca00d5b 100644 --- a/profiler/advisor/advisor_backend/advice_base.py +++ b/profiler/advisor/advisor_backend/advice_base.py @@ -23,7 +23,7 @@ class AdviceBase: ADVICE = "advice" def __init__(self, collection_path: str): - self.collection_path = os.path.realpath(collection_path) + self.collection_path = os.path.abspath(collection_path) self.bottelneck = '' self.output_format_data = { self.DATA: [], diff --git a/profiler/advisor/advisor_backend/advice_factory/advice_factory.py b/profiler/advisor/advisor_backend/advice_factory/advice_factory.py index 639f4800cfe8c9acdc8fe7ea5f65a43fc8892b2b..1b4b0c1be4ba466588469728e5aaa0cc3a1422ac 100644 --- a/profiler/advisor/advisor_backend/advice_factory/advice_factory.py +++ b/profiler/advisor/advisor_backend/advice_factory/advice_factory.py @@ -19,7 +19,7 @@ from common_func.path_manager import PathManager class AdviceFactory: def __init__(self, collection_path: str): - self.collection_path = os.path.realpath(collection_path) + self.collection_path = os.path.abspath(collection_path) @staticmethod def run_advice(self, advice: str, kwargs: dict): diff --git a/profiler/advisor/advisor_backend/cluster_advice/cluster_advice_base.py b/profiler/advisor/advisor_backend/cluster_advice/cluster_advice_base.py index e9be4675963a9cd48da3b4cd91ee646f8e82468b..2cc1eebebc921510e44e8869bd70c423995efeb0 100644 --- a/profiler/advisor/advisor_backend/cluster_advice/cluster_advice_base.py +++ b/profiler/advisor/advisor_backend/cluster_advice/cluster_advice_base.py @@ -14,11 +14,14 @@ # limitations under the License. import os +import logging from abc import abstractmethod from common_func.constant import Constant from advice_base import AdviceBase from cluster_analysis import Interface +logger = logging.getLogger() + class ClusterAdviceBase(AdviceBase): def __init__(self, collection_path: str): @@ -37,11 +40,11 @@ class ClusterAdviceBase(AdviceBase): """ for file in os.listdir(self.collection_path): if file == 'cluster_analysis_output': - print("[INFO]Cluster has been analyzed " - "because of the existence of cluster analysis output directory.") - print("[INFO]Skip Cluster analyze backend.") + logger.info("Cluster has been analyzed " + "because of the existence of cluster analysis output directory.") + logger.info("Skip Cluster analyze backend.") return - print("[INFO] cluster analysis is in the process, please wait...") + logger.info("cluster analysis is in the process, please wait...") self.cluster_analyze() def cluster_analyze(self): diff --git a/profiler/advisor/advisor_backend/cluster_advice/kernel_cluster_advice.py b/profiler/advisor/advisor_backend/cluster_advice/kernel_cluster_advice.py index 6fa83c765f5fe1f4ac20dcc62895fe0450e338ce..e7b334bbfe2cea02214e2900094c2c005f463432 100644 --- a/profiler/advisor/advisor_backend/cluster_advice/kernel_cluster_advice.py +++ b/profiler/advisor/advisor_backend/cluster_advice/kernel_cluster_advice.py @@ -5,7 +5,7 @@ from common_func.constant import Constant from common_func_advisor.constant import Constant as AdvisorConstant from cluster_advice.cluster_advice_base import ClusterAdviceBase from cluster_data_preprocess.pytorch_data_preprocessor import PytorchDataPreprocessor - +from profiler.cluster_analyse.common_func.file_manager import FileManager class KernelClusterAdvice(ClusterAdviceBase): COLUMNS_TO_GROUP = ["Name", "Input Shapes", "Input Data Types", "Output Shapes"] @@ -32,6 +32,7 @@ class KernelClusterAdvice(ClusterAdviceBase): kernel_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.KERNEL_DETAILS_CSV) if kernel_file: # 判断csv文件大小 + FileManager.check_file_size(kernel_file) PathManager.check_path_readable(kernel_file) # 读取CSV文件 df_temp = pd.read_csv(kernel_file) diff --git a/profiler/advisor/advisor_backend/cluster_advice/slow_link_advice.py b/profiler/advisor/advisor_backend/cluster_advice/slow_link_advice.py index f8a625242f3939602cbb7b8391cd8062e21fe01b..8d299326236461614afb152fbcdb62cc0fb61d94 100644 --- a/profiler/advisor/advisor_backend/cluster_advice/slow_link_advice.py +++ b/profiler/advisor/advisor_backend/cluster_advice/slow_link_advice.py @@ -96,6 +96,8 @@ class SlowLinkAdvice(ClusterAdviceBase): def produce_bottleneck(self, link_type: str): data_list = [rank_dict.get(link_type, 0) for rank_id, rank_dict in self.rank_bw_dict.items()] + if len(data_list) == 0: + raise ValueError("Cannot calculate avg_bw, data_list is empty!") avg_bw = round(sum(data_list) / len(data_list), 3) if avg_bw == 0: return diff --git a/profiler/advisor/advisor_backend/common_func_advisor/trace_view_preprocessor.py b/profiler/advisor/advisor_backend/common_func_advisor/trace_view_preprocessor.py index 7b9baa32d9423a46bf93d563a6fabbbbb652aaf8..6eced27efd203e2e7b23e47cd33ab71ce8c4e240 100644 --- a/profiler/advisor/advisor_backend/common_func_advisor/trace_view_preprocessor.py +++ b/profiler/advisor/advisor_backend/common_func_advisor/trace_view_preprocessor.py @@ -104,7 +104,7 @@ class TraceViewPreProcessor: check whether op is hcom send or recv op """ # eg: hcom_BatchSendRecv__101_0_1 - p1 = re.compile(r'hcom_\w+SendRecv__\d+') + p1 = re.compile(r'^hcom_\w+SendRecv__\d+') # eg: hcom_send__101_0_1 p2 = re.compile(r'hcom_send__\d+') # eg: hcom_receive__101_0_1 diff --git a/profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py b/profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py index c85c14d618ceda199c9c376abc27a3581eed97b8..8ae109856a3b8719d8d7fef2ff15ca7a1270eccc 100644 --- a/profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py +++ b/profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py @@ -18,6 +18,7 @@ import multiprocessing import pandas as pd import numpy as np +from common_func.path_manager import PathManager from common_func_advisor.constant import Constant from .op_perf import OpPerfFactory @@ -27,6 +28,7 @@ class CSVAnalyzer: self._path = path def process(self): + PathManager.check_path_readable(self._path) df = pd.read_csv(self._path, dtype={"Start Time(us)": str}) # 分析是否存在可融合的算子 op_type_list = df["Type"].tolist() diff --git a/profiler/advisor/advisor_backend/compute_advice/npu_slow_advice.py b/profiler/advisor/advisor_backend/compute_advice/npu_slow_advice.py index caff1c792c2171c33a4dd876b0741d6c215c5766..48522cf55a4cfb3f89083c3ac69ec7b22b295195 100644 --- a/profiler/advisor/advisor_backend/compute_advice/npu_slow_advice.py +++ b/profiler/advisor/advisor_backend/compute_advice/npu_slow_advice.py @@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC +import os import multiprocessing import pandas as pd +from common_func.path_manager import PathManager from compute_advice.compute_advice_base import ComputeAdviceBase from compute_advice.npu_fused.op_perf import OpPerfFactory from common_func_advisor.constant import Constant @@ -34,6 +36,7 @@ class NpuSlowAdvice(ComputeAdviceBase, ABC): @staticmethod def save_to_excel(data: pd.DataFrame, file_path: str) -> None: + PathManager.check_path_writeable(os.path.dirname(file_path)) writer = pd.ExcelWriter(file_path, engine="xlsxwriter", mode="w") data.index.name = Constant.TITLE.INDEX data.to_excel(writer, index=True, sheet_name=NpuSlowAdvice.OP_PERF_SHEET) @@ -73,6 +76,7 @@ class NpuSlowAdvice(ComputeAdviceBase, ABC): return self.data def process(self): + PathManager.check_path_readable(self.kernel_details_path) self.data = pd.read_csv(self.kernel_details_path, dtype={"Start Time(us)": str}) # 去除末尾的\t分隔符 self.data["Start Time(us)"] = self.data["Start Time(us)"].apply(lambda x: x[:-1]) diff --git a/profiler/advisor/advisor_backend/interface.py b/profiler/advisor/advisor_backend/interface.py index 3e20c26d4d7bb000b20c28439b28ddf4811f057f..deb68822ec4ac025e6cf647cd031415618cc415e 100644 --- a/profiler/advisor/advisor_backend/interface.py +++ b/profiler/advisor/advisor_backend/interface.py @@ -30,7 +30,7 @@ from advisor_backend.advice_factory.overall_advice_factory import OverallAdviceF class Interface: def __init__(self, collection_path: str): - self.collection_path = os.path.realpath(collection_path) + self.collection_path = os.path.abspath(collection_path) self._factory_controller = FactoryController(collection_path) def get_data(self: any, mode: str, advice: str, **kwargs): @@ -50,7 +50,7 @@ class FactoryController: } def __init__(self, collection_path: str): - self.collection_path = os.path.realpath(collection_path) + self.collection_path = os.path.abspath(collection_path) self.temp_input_path = None def create_advice_factory(self, mode: str, input_path: str): diff --git a/profiler/advisor/analyzer/analyzer_controller.py b/profiler/advisor/analyzer/analyzer_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..28460dfa4e0571644afbc84c58d7c7756db2e9ff --- /dev/null +++ b/profiler/advisor/analyzer/analyzer_controller.py @@ -0,0 +1,921 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import logging +import json +import sys +import os +import platform +import multiprocessing as mp +from multiprocessing import Manager +from pathlib import Path + +import psutil + +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "compare_tools")) +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "cluster_analyse")) + +from profiler.advisor.analyzer.cluster.slow_rank_analyzer import SlowRankAnalyzer +from profiler.advisor.analyzer.cluster.slow_link_analyzer import SlowLinkAnalyzer +from profiler.advisor.analyzer.computation.pp_stage_computation_analyzer import PPStageComputationAnalyzer +from profiler.advisor.analyzer.overall.overall_summary_analyzer import OverallSummaryAnalyzer +from profiler.advisor.config.config import Config +from profiler.advisor.common import constant as const +from profiler.advisor.common.analyzer_scopes import SupportedScopes +from profiler.advisor.common.async_analysis_status import AsyncAnalysisStatus +from profiler.advisor.common.enum_params_parser import EnumParamsParser +from profiler.advisor.utils.utils import Timer, safe_index_value, safe_division, safe_index, convert_to_int +from profiler.advisor.interface.interface import Interface +from profiler.cluster_analyse.cluster_data_preprocess.pytorch_data_preprocessor import PytorchDataPreprocessor +from profiler.prof_common.path_manager import PathManager +from profiler.compare_tools.compare_backend.utils.constant import Constant as CompareConstant + +# 以spawn模式启动多进程,避免fork主进程资源。如果主进程逻辑较为复杂,fork可能会导致异常。 +mp.set_start_method("spawn", force=True) +logger = logging.getLogger() + + +class AsyncParams: + """处理用户异步请求的输入参数,包括cli arguments和环境变量两类参数.""" + user_valid_arguments = {} + user_valid_envs = {} + user_non_enum_params = {} + user_invalid_values = [] + user_total_params = {} + + @staticmethod + def parse_async_list_params(key, value, option_values, key_type, value_type): + if isinstance(value, list): + value_list = value + else: + value_list = [_.strip(" ") for _ in str(value).split(",")] + + if sorted(value_list) not in [sorted(option) for option in option_values]: + AsyncParams.user_invalid_values.append( + {"key": key, "invalid value": value, "optional values": option_values, + "required value type": value_type}) + return + if key_type == EnumParamsParser.ENVS: + AsyncParams.user_valid_envs[key.upper()] = ",".join(value_list) + elif key_type == EnumParamsParser.ARGUMENTS: + AsyncParams.user_valid_arguments[key] = value_list + + @staticmethod + def parse_async_int_params(key, value, option_values, key_type, value_type): + if convert_to_int(value) not in option_values: + AsyncParams.user_invalid_values.append( + {"key": key, "invalid value": value, "optional values": option_values, + "required value type": value_type}) + return + + if key_type == EnumParamsParser.ENVS: + AsyncParams.user_valid_envs[key.upper()] = str(convert_to_int(value)) + elif key_type == EnumParamsParser.ARGUMENTS: + AsyncParams.user_valid_arguments[key] = convert_to_int(value) + + @staticmethod + def parse_async_str_params(key, value, option_values, key_type, value_type): + if str(value) not in option_values: + AsyncParams.user_invalid_values.append( + {"key": key, "invalid value": value, "optional values": option_values, + "required value type": value_type}) + return + if key_type == EnumParamsParser.ENVS: + AsyncParams.user_valid_envs[key.upper()] = str(value) + elif key_type == EnumParamsParser.ARGUMENTS: + AsyncParams.user_valid_arguments[key] = str(value) + + @staticmethod + def parse_async_boolean_params(key, value, option_values, key_type, value_type): + + if str(value).lower() not in ["true", "false"]: + AsyncParams.user_invalid_values.append( + {"key": key, "invalid value": value, "optional values": option_values, + "required value type": value_type}) + return + + if key_type == EnumParamsParser.ENVS: + AsyncParams.user_valid_envs[key.upper()] = str(value) + elif key_type == EnumParamsParser.ARGUMENTS: + AsyncParams.user_valid_arguments[key] = str(value).lower() == "true" + + @staticmethod + def parse_params(user_async_params): + params_parser = EnumParamsParser() + valid_env_keys = [key.lower() for key in params_parser.get_envs_keys()] + valid_arg_keys = [key.lower() for key in params_parser.get_arguments_keys()] + + for key, value in user_async_params.items(): + key = key.lower() + if key not in valid_env_keys + valid_arg_keys: + AsyncParams.user_non_enum_params[key] = value + continue + + if key in valid_env_keys: + # 环境变量均大写,异步调用入参到analyzer controller时支持用户使用小写配置环境变量 + option_values = params_parser.get_options(key.upper()) + value_type = params_parser.get_value_type(key.upper()) + key_type = params_parser.ENVS + else: + option_values = params_parser.get_options(key) + value_type = params_parser.get_value_type(key) + key_type = params_parser.ARGUMENTS + + if hasattr(AsyncParams, f"parse_async_{value_type}_params"): + getattr(AsyncParams, f"parse_async_{value_type}_params")(key, value, option_values, key_type, + value_type) + + AsyncParams.user_total_params["async_analysis_env"] = AsyncParams.user_valid_envs + AsyncParams.user_total_params.update(AsyncParams.user_valid_arguments) + AsyncParams.user_total_params.update(AsyncParams.user_non_enum_params) + + +class AnalyzerController: + CLUSTER_RANK_THRESHOLD = 2 + SDMA_SUPPORT_SCOPES = [SupportedScopes.BANDWIDTH_CONTENTION_DETECTION] + RDMA_SUPPORT_SCOPES = [SupportedScopes.PACKET] + COMMUNICATION_MAPPING = { + SlowLinkAnalyzer.SDMA: SDMA_SUPPORT_SCOPES, + SlowLinkAnalyzer.RDMA: RDMA_SUPPORT_SCOPES + } + + def __init__(self): + self.dimensions = Interface.all_dimension + self.kwargs = {} + self.slow_rank_analyzer = None + self.slow_link_analyzer = None + self.cluster_local_data_map = {} + self.default_rank_id = None + self.rank_id_map = {} + self._is_cluster = False + self.analysis_process_resp = Manager().dict() + + @staticmethod + def _set_analysis_process_priority(pid): + # 将分析进程优先级设置为最低,避免因为分析进程阻塞其他任务进程,unix上19表示最低优先级 + unix_process_lowest_priority = 19 + windows_platform = "windows" + linux_platform = "linux" + p = psutil.Process(pid) + if platform.system().lower() == windows_platform: + p.nice(psutil.BELOW_NORMAL_PRIORITY_CLASS) + elif platform.system().lower() == linux_platform: + p.nice(unix_process_lowest_priority) + + @staticmethod + def _check_profiling_path_valid(profiling_path): + PathManager.input_path_common_check(profiling_path) + + if not Path(profiling_path).exists(): + logger.error("Profiling path is not existed. Invalid profiling path: %s", profiling_path) + return False + + return True + + @staticmethod + def _whether_include_mindspore_prof(profiling_path): + # 暂不支持Mindspore数据,支持后可删除该限制 + ASCEND_MS = "ascend_ms" + + has_ascend_ms_dirs = False + for root, dirs, _ in os.walk(profiling_path): + if root.endswith(ASCEND_MS): + has_ascend_ms_dirs = True + break + for dir_name in dirs: + if dir_name.endswith(ASCEND_MS): + has_ascend_ms_dirs = True + break + if has_ascend_ms_dirs: + break + + if has_ascend_ms_dirs: + logger.error("Advisor does not support data from MindSpore now, existing dirs end with 'ascend_ms'") + return True + + return False + + @staticmethod + def _get_step_rank_for_cluster_statistic_diff(target_cluster_statistic_data, benchmark_cluster_statistic_data, + headers, dimension, get_max=False): + if dimension not in headers: + logger.error("Error dimension %s for cluster statistics data, optionals are %s.", dimension, headers) + return None, None, None + + dimension_index = safe_index_value(headers, dimension) + diff_record = [] + # 对比目标profiling和benchmark profiling 每张卡的计算和下发和带宽,取计算、下发、带宽差异最大的卡进行下一步分析 + for target_row_data, benchmark_row_data in zip(target_cluster_statistic_data, benchmark_cluster_statistic_data): + target_data = safe_index(target_row_data, dimension_index) + benchmark_data = safe_index(benchmark_row_data, dimension_index) + + if not isinstance(target_data, (int, float)) or not isinstance(benchmark_data, (int, float)): + continue + diff_record.append(target_data - benchmark_data) + + if SlowRankAnalyzer.compute_max_gap_ratio(diff_record, safe_division(sum(diff_record), len( + diff_record))) < SlowRankAnalyzer.RATIO_THRESHOLD: + return None, None, None + + value = max(diff_record) if get_max else min(diff_record) + value_index = safe_index_value(diff_record, value) + + step_value_index = safe_index_value(headers, "step") + rank_id_value_index = safe_index_value(headers, "rank_id") + + step = safe_index(safe_index(target_cluster_statistic_data, value_index, []), step_value_index) + benchmark_step = safe_index(safe_index(benchmark_cluster_statistic_data, value_index, []), step_value_index) + target_rank_id = safe_index(safe_index(target_cluster_statistic_data, value_index, []), rank_id_value_index) + benchmark_rank_id = safe_index(safe_index(benchmark_cluster_statistic_data, value_index, []), + rank_id_value_index) + + if target_rank_id != benchmark_rank_id: + logger.error( + "Rank ids of target profiling must keep the same as benchmark profiling, skip cluster comparison") + return None, None, None + + return step, benchmark_step, target_rank_id + + @staticmethod + def _init_async_analysis_env(kwargs): + envs = kwargs.get("async_analysis_env", {}) + for key, value in envs.items(): + os.environ[key] = value + + def format_async_analysis_params(self, pid, async_resp, dimensions, kwargs): + + AsyncParams.parse_params(kwargs) + dimensions = AsyncParams.user_total_params.get("analysis_dimensions") or dimensions + + if AsyncParams.user_invalid_values: + error_msg = "Got invalid arguments as follows: \n " + for index, invalid_value in enumerate(AsyncParams.user_invalid_values): + error_msg += f"{index + 1}. Key '{invalid_value.get('key')}', " \ + f"invalid value '{invalid_value.get('invalid value')}', " \ + f"optional valid values '{invalid_value.get('optional values')}', " \ + f"required value type '{invalid_value.get('required value type')}'.\n " + self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, + status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED) + raise ValueError(error_msg) + + logger.warning("User parameters for async analysis is as follows:\n %s", + json.dumps(AsyncParams.user_total_params, indent=4)) + return dimensions, AsyncParams.user_total_params + + def do_analysis(self, dimensions, **kwargs): + pid = os.getpid() + resp = {"id": pid} + output_path = kwargs.get("output_path") + + AnalyzerController._set_analysis_process_priority(pid) + if kwargs.get("is_async_analysis"): + del kwargs["is_async_analysis"] + dimensions, kwargs = self.format_async_analysis_params(pid, resp, dimensions, kwargs) + AnalyzerController._init_async_analysis_env(kwargs) + + try: + if output_path: + + PathManager.check_input_directory_path(output_path) + if os.path.exists(output_path): + PathManager.check_path_owner_consistent(output_path) + else: + PathManager.make_dir_safety(output_path) + + Config().set_config("_work_path", output_path) + Config().set_log_path(f"mstt_advisor_{Timer().strftime}.xlsx") + + self._do_analysis(dimensions, pid=pid, async_resp=resp, **kwargs) + except Exception as e: + self._update_analysis_process_resp(pid, resp, status_code=AsyncAnalysisStatus.INNER_ERROR_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED, error_msg=str(e)) + logger.error(e) + raise RuntimeError("Do analysis error.") from e + + def async_do_analysis(self, dimensions, **kwargs): + # 异步分析,用于部署服务,通过接口查询异步作业状态 + kwargs["is_async_analysis"] = True + + async_analysis_process = mp.Process(target=self.do_analysis, args=(dimensions,), kwargs=kwargs, + name="Async advisor performance analysis") + async_analysis_process.start() + self._update_analysis_process_resp(async_analysis_process.pid, {"id": async_analysis_process.pid}, + status_code=AsyncAnalysisStatus.NON_FAILED_STATUS_CODE, + status=AsyncAnalysisStatus.ANALYZING) + return async_analysis_process + + def get_response_by_pid(self, pid): + def _is_pid_exists(pid): + try: + psutil.Process(pid) + return True + except psutil.NoSuchProcess: + return False + + pid_not_exist_response = dict(id=pid, status_code=AsyncAnalysisStatus.NOT_FOUND_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED, + error_msg="The advisor task id does not exist") + if pid not in self.analysis_process_resp: + return pid_not_exist_response + + response = self.analysis_process_resp.get(pid) + if response.get("status") not in [AsyncAnalysisStatus.FAILED, + AsyncAnalysisStatus.SUCCESS] and not _is_pid_exists(pid): + return pid_not_exist_response + return response + + def single_rank_analysis(self, profiling_path, benchmark_profiling_path=None): + job_list = [] + + profiling_path = self._get_profiling_path_by_rank(profiling_path) + benchmark_profiling_path = self._get_profiling_path_by_rank(benchmark_profiling_path) + + # 单卡场景无集群分析 + for dim in [Interface.CLUSTER]: + if dim in self.dimensions: + self.dimensions.remove(dim) + + for dimension in self.dimensions: + dimension_analysis_func_name = f"{dimension}_analysis" + if not hasattr(self, dimension_analysis_func_name): + continue + logger.info("Start %s analysis", dimension) + job_list += getattr(self, dimension_analysis_func_name)(profiling_path) + + if benchmark_profiling_path: + # kernel/api 比对 + compare_profiling_list = [ + dict(profiling_path=profiling_path, benchmark_profiling_path=benchmark_profiling_path, + compare_mode=CompareConstant.KERNEL_COMPARE), + dict(profiling_path=profiling_path, benchmark_profiling_path=benchmark_profiling_path, + compare_mode=CompareConstant.API_COMPARE) + ] + + job_list += self._profiling_comparison(compare_profiling_list) + else: + self.overall(profiling_path) + + return job_list + + def do_cluster_analysis(self, profiling_path, benchmark_profiling_path=None): + job_list = [] + + # 单集群profiling分析:下发、通信、计算、显存/内存 + for dimension in self.dimensions: + dimension_analysis_func_name = f"cluster_{dimension}_analysis" + if not hasattr(self, dimension_analysis_func_name): + continue + logger.info("Start cluster %s analysis", dimension) + job_list += getattr(self, dimension_analysis_func_name)(profiling_path) + + self.overall(profiling_path) + + if benchmark_profiling_path: + # 两个集群profiling比对分析 + job_list += self._cluster_profiling_comparison(profiling_path, benchmark_profiling_path) + return job_list + + def overall(self, profiling_path): + from profiler.advisor.analyzer.overall.environment_variable_analyzer import EnvironmentVariabelAnalyzer + env_analyzer = EnvironmentVariabelAnalyzer(profiling_path) + env_analyzer.optimize() + + if self._is_cluster: + self.slow_rank_analyzer.optimize(template_key=Interface.OVERALL) + self.slow_link_analyzer.optimize(template_key=Interface.OVERALL) + else: + overall_analyzer = OverallSummaryAnalyzer(profiling_path) + overall_analyzer.optimize() + + def schedule_analysis(self, profiling_path, benchmark_profiling_path=None, step=None, benchmark_step=None, + rank=None, **kwargs): + # 任意单卡的下发分析 + + input_kwargs = copy.deepcopy(self.kwargs) + job_list = [] + + input_kwargs["profiling_path"] = profiling_path + input_kwargs["benchmark_profiling_path"] = benchmark_profiling_path + input_kwargs["step"] = step + input_kwargs["benchmark_step"] = benchmark_step + input_kwargs["rank"] = rank + + for dimension in [Interface.SCHEDULE]: + for scope in Interface.get_scope(dimension): + interface = Interface(**input_kwargs) + job_list.append((dimension, scope, interface, input_kwargs)) + return job_list + + def computation_analysis(self, profiling_path, benchmark_profiling_path=None, step=None, + benchmark_step=None, stage=None, **kwargs): + # 任意单卡的计算分析 + + input_kwargs = copy.deepcopy(self.kwargs) + input_kwargs["profiling_path"] = profiling_path + input_kwargs["benchmark_profiling_path"] = benchmark_profiling_path + input_kwargs["step"] = step + input_kwargs["benchmark_step"] = benchmark_step + input_kwargs["stage"] = stage + input_kwargs["rank"] = kwargs.get("rank") + job_list = [] + + for dimension in [Interface.COMPUTATION]: + for scope in Interface.get_scope(dimension): + if scope == SupportedScopes.STAGE_COMPUTE: + continue + interface = Interface(**input_kwargs) + job_list.append((dimension, scope, interface, input_kwargs)) + return job_list + + def memory_analysis(self, profiling_path, benchmark_profiling_path=None, step=None, benchmark_step=None, rank=None): + # 任意单卡的内存分析 + + input_kwargs = copy.deepcopy(self.kwargs) + job_list = [] + + input_kwargs["profiling_path"] = profiling_path + input_kwargs["benchmark_profiling_path"] = benchmark_profiling_path + input_kwargs["step"] = step + input_kwargs["benchmark_step"] = benchmark_step + input_kwargs["rank"] = rank + + for dimension in [Interface.MEMORY]: + for scope in Interface.get_scope(dimension): + interface = Interface(**input_kwargs) + job_list.append((dimension, scope, interface, input_kwargs)) + return job_list + + def communication_analysis(self, profiling_path, benchmark_profiling_path=None, **kwargs): + + job_list = [] + supported_trans_type = [SlowLinkAnalyzer.SDMA, SlowLinkAnalyzer.RDMA] + step = kwargs.get("step", None) + benchmark_step = kwargs.get("benchmark_step", None) + bandwidth_type = kwargs.get("bandwidth_type", None) + scope = kwargs.get("scope", None) + if bandwidth_type is not None and bandwidth_type not in supported_trans_type: + logger.error("Error transit type %s, optionals are %s", bandwidth_type, supported_trans_type) + return job_list + + job_list += self._communication_analysis(profiling_path=profiling_path, + benchmark_profiling_path=benchmark_profiling_path, + step=step, benchmark_step=benchmark_step, + scope=scope, bandwidth_type=bandwidth_type) + + return job_list + + def cluster_schedule_analysis(self, profiling_path): + # 目标集群profiling数据下发分析,不包含两个集群profiling数据的比对分析 + + job_list = [] + global_step_rank = self.slow_rank_analyzer.get_global_step_rank(SlowRankAnalyzer.FREE) + + info_msg = "For cluster schedule analysis, " + slow_rank_id = global_step_rank.get("maximum", {}).get("rank_id") + if slow_rank_id is not None: + info_msg += f"maximum free for rank {slow_rank_id}" + else: + slow_rank_id = self.default_rank_id + info_msg += f"no slow rank with free time, analysis for default rank {slow_rank_id}" + + fast_rank_id = global_step_rank.get("minimum", {}).get("rank_id") + + slow_step = global_step_rank.get("maximum", {}).get("step") + fast_step = global_step_rank.get("minimum", {}).get("step") + + if slow_step is not None: + info_msg += f" and step {slow_step}" + logger.info(info_msg) + + kwargs = dict(profiling_path=self._get_profiling_path_by_rank(profiling_path, slow_rank_id), + benchmark_profiling_path=self._get_profiling_path_by_rank(profiling_path, fast_rank_id), + step=slow_step, benchmark_step=fast_step, + rank=slow_rank_id, benchmark_rank=fast_rank_id, + compare_mode=CompareConstant.API_COMPARE) + + job_list += self.schedule_analysis(**kwargs) + + rank_id_valid = slow_rank_id is not None and fast_rank_id is not None and fast_rank_id != slow_rank_id + if self.kwargs.get("benchmark_profiling_path") is None and rank_id_valid: + # 当用户指定benchmark profiling path时,不进行目标集群profiling的内部快慢卡对比 + logger.info("Enable schedule comparison of fast and slow rank/step") + job_list += self._profiling_comparison([kwargs]) + return job_list + + def cluster_communication_analysis(self, profiling_path): + job_list = [] + + for dimension in [Interface.COMMUNICATION]: + for scope in Interface.get_scope(dimension): + analyzer_class = Interface.get_analyzer(dimension, scope) + if hasattr(analyzer_class, "requires_cluster_dataset") and getattr(analyzer_class, + "requires_cluster_dataset"): + + # 如果不依赖数据集,或者依赖的是ClusterDataset,则不用根据带宽确定需要分析的特定rank + kwargs = copy.deepcopy(self.kwargs) + kwargs["profiling_path"] = profiling_path + interface = Interface(**kwargs) + job_list.append((dimension, scope, interface, kwargs)) + else: + # 非ClusterDataset场景,需要根据带宽大小分析特定的rank + for bandwidth_type in [SlowLinkAnalyzer.SDMA, SlowLinkAnalyzer.RDMA]: + global_step_rank = self.slow_link_analyzer.get_global_step_rank(bandwidth_type) + # 获取带宽最小的卡进行分析 + target_rank_id = global_step_rank.get("minimum", {}).get("rank_id") + if target_rank_id is None: + target_rank_id = self.default_rank_id + step = global_step_rank.get("minimum", {}).get("step") + analysis_profiling_path = self._get_profiling_path_by_rank(profiling_path, target_rank_id) + + info_msg = f"Minimum {bandwidth_type} bandwidth for rank {target_rank_id} " + if step: + info_msg += f"and step {step}" + logger.info(info_msg) + + job_list += self.communication_analysis(analysis_profiling_path, step=step, + bandwidth_type=bandwidth_type, scope=scope) + + return job_list + + def cluster_computation_analysis(self, profiling_path): + # 目标集群profiling数据计算分析,不包含两个集群profiling数据的比对分析;如果有pp stage,则对不同stage进行计算分析 + + job_list = [] + global_step_rank = self.slow_rank_analyzer.get_global_step_rank(SlowRankAnalyzer.COMPUTE) + stage_step_rank = self.slow_rank_analyzer.get_stage_step_rank(SlowRankAnalyzer.COMPUTE) + + if stage_step_rank: + job_list = self._stage_computation_analysis(profiling_path, stage_step_rank, job_list) + else: + job_list = self._global_computation_analysis(profiling_path, global_step_rank, job_list) + return job_list + + def cluster_memory_analysis(self, profiling_path): + # 目标集群profiling数据内存分析,当前memory识别的两个算子,导致的问题都是大的free,因此选择FREE最慢的卡进行分析 + + job_list = [] + global_step_rank = self.slow_rank_analyzer.get_global_step_rank(SlowRankAnalyzer.FREE) + + info_msg = "For cluster memory analysis, " + slow_rank_id = global_step_rank.get("maximum", {}).get("rank_id") + if slow_rank_id is not None: + info_msg += f"maximum free for rank {slow_rank_id}" + else: + slow_rank_id = self.default_rank_id + info_msg += f"no slow rank with free time, analysis for default rank {slow_rank_id}" + + slow_step = global_step_rank.get("maximum", {}).get("step") + if slow_step is not None: + info_msg += f" and step {slow_step}" + logger.info(info_msg) + + analysis_profiling_path = self._get_profiling_path_by_rank(profiling_path, slow_rank_id) + + job_list += self.memory_analysis(analysis_profiling_path, step=slow_step, rank=slow_rank_id) + return job_list + + def _do_analysis(self, dimensions, pid=0, async_resp=None, **kwargs): + self.dimensions = dimensions + self.kwargs = kwargs + result_list = [] + profiling_path = PathManager.get_realpath(self.kwargs.get("profiling_path")) + benchmark_profiling_path = self.kwargs.get("benchmark_profiling_path") + if benchmark_profiling_path: + benchmark_profiling_path = PathManager.get_realpath(benchmark_profiling_path) + + if not self._check_profiling_path_valid(profiling_path): + error_msg = f"Got invalid argument '-d/--profiling_path' {profiling_path}, skip analysis" + self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, + status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED) + logger.error(error_msg) + return + + # 暂不支持Mindspore数据,支持后可删除该限制 + if self._whether_include_mindspore_prof(profiling_path): + error_msg = f"Got *_ascend_ms dirs from {profiling_path}, skip analysis" + self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, + status_code=AsyncAnalysisStatus.FAILED_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED) + logger.error(error_msg) + return + + if benchmark_profiling_path and not self._check_profiling_path_valid(benchmark_profiling_path): + error_msg = (f"Got invalid argument '-bp/--benchmark_profiling_path' {benchmark_profiling_path}, " + f"skip analysis") + self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, + status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED) + logger.error(error_msg) + return + + self._is_cluster = self._is_cluster_profiling(profiling_path) + if benchmark_profiling_path: + # 构建benchmark profiling的map,用于根据rank获取profiling路径,否则无法进行比对 + is_benchmark_cluster = self._is_cluster_profiling(benchmark_profiling_path) + is_comparison_path_valid = (self._is_cluster and is_benchmark_cluster) or ( + not self._is_cluster and not is_benchmark_cluster) + if not is_comparison_path_valid: + error_msg = f"Only support profiling comparison for '1 npu vs 1 gpu/npu' and 'multi npus vs multi npus'" + self._update_analysis_process_resp(pid, async_resp, error_msg=error_msg, + status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED) + logger.error(error_msg) + return + + if not self._is_cluster: + job_list = self.single_rank_analysis(profiling_path, benchmark_profiling_path) + else: + self.slow_rank_analyzer = SlowRankAnalyzer(profiling_path) + self.slow_link_analyzer = SlowLinkAnalyzer(profiling_path) + job_list = self.do_cluster_analysis(profiling_path, benchmark_profiling_path) + + for i, (dimension, scope, interface, kwargs) in enumerate(job_list[::-1]): + result_list.append( + interface.get_result(dimension, scope, render_html=i == len(job_list) - 1, output_dict=False, + **kwargs) + ) + + for result in result_list[::-1]: + if result and hasattr(result, "show"): + result.show() + break + self._get_analysis_finished_resp(pid, async_resp) + + def _get_scopes(self, scope=None, bandwidth_type=SlowLinkAnalyzer.SDMA): + """ + Args: + scope: analyzer type + bandwidth_type: analysis standard + Returns: + scope lists + """ + scopes = [] + if scope: + if scope in self.COMMUNICATION_MAPPING.get(bandwidth_type, self.SDMA_SUPPORT_SCOPES): + scopes.append(scope) + return scopes + for dimension in [Interface.COMMUNICATION]: + for scope_ in Interface.get_scope(dimension): + if scope_ in self.SDMA_SUPPORT_SCOPES or scope_ in self.RDMA_SUPPORT_SCOPES: + scopes.append(scope_) + return scopes + + def _communication_analysis(self, **child_kwargs): + kwargs = copy.deepcopy(self.kwargs) + job_list = [] + + kwargs["profiling_path"] = child_kwargs.get("profiling_path", "") + kwargs["benchmark_profiling_path"] = child_kwargs.get("benchmark_profiling_path", "") + kwargs["step"] = child_kwargs.get("step", -1) + kwargs["benchmark_step"] = child_kwargs.get("benchmark_step", -1) + bandwidth_type = child_kwargs.get("bandwidth_type", SlowLinkAnalyzer.SDMA) + scope = child_kwargs.get("scope", None) + + for scope_ in self._get_scopes(scope, bandwidth_type): + interface = Interface(**kwargs) + job_list.append((Interface.COMMUNICATION, scope_, interface, kwargs)) + + return job_list + + def _profiling_comparison(self, compare_profiling_list): + job_list = [] + disable_profiling_comparison = os.getenv(const.DISABLE_PROFILING_COMPARISON) + if disable_profiling_comparison is not None and disable_profiling_comparison.lower() == "true": + logger.info( + "Skip profiling comparison due to longer processing time due to env 'DISABLE_PROFILING_COMPARISON'") + return job_list + + for index, _kwargs in enumerate(compare_profiling_list): + kwargs = copy.deepcopy(self.kwargs) + kwargs.update(_kwargs) + compare_profiling_list[index] = kwargs + + compare_kwargs = { + "profiling_path": kwargs.get("profiling_path"), + "compare_profiling_list": compare_profiling_list, + } + + interface = Interface(**compare_kwargs) + job_list.append((Interface.COMPARISON, SupportedScopes.COMPARISON, interface, compare_kwargs)) + + return job_list + + def _cluster_profiling_comparison(self, profiling_path, benchmark_profiling_path): + # 从计算、下发和通信三个维度对集群profiling数据进行对比 + + job_list = [] + benchmark_profiling_path = self._get_profiling_path_by_rank(benchmark_profiling_path) + benchmark_slow_rank_analyzer = SlowRankAnalyzer(benchmark_profiling_path) + benchmark_slow_link_analyzer = SlowLinkAnalyzer(benchmark_profiling_path) + + # 计算和下发分析 + job_list += self._cluster_data_comparison(profiling_path, + benchmark_profiling_path, + self.slow_rank_analyzer, + benchmark_slow_rank_analyzer, + get_max=True) + + # 通信分析 + job_list += self._cluster_data_comparison(profiling_path, + benchmark_profiling_path, + self.slow_link_analyzer, + benchmark_slow_link_analyzer, + get_max=False) + return job_list + + def _cluster_data_comparison(self, profiling_path, benchmark_profiling_path, target_cluster_analyzer, + benchmark_cluster_analyzer, get_max=False): + # #low rank/slow link结果逐行对比获取差值最大的rank和step进行单卡分析 + job_list = [] + + if isinstance(target_cluster_analyzer, SlowRankAnalyzer): + comparison_dims = [SlowRankAnalyzer.COMPUTE, SlowRankAnalyzer.FREE] + comparison_modes = [CompareConstant.KERNEL_COMPARE, CompareConstant.API_COMPARE] + elif isinstance(target_cluster_analyzer, SlowLinkAnalyzer): + comparison_dims = [SlowLinkAnalyzer.SDMA_BANDWIDTH, SlowLinkAnalyzer.RDMA_BANDWIDTH] + comparison_modes = [None, None] + else: + return job_list + + target_data = target_cluster_analyzer.format_datas.get("data", []) + benchmark_data = benchmark_cluster_analyzer.format_datas.get("data", []) + headers = benchmark_cluster_analyzer.format_datas.get("headers", []) + + if len(target_data) != len(benchmark_data): + logger.warning( + "The product of ranks and steps of Benchmark profiling is not equals to target profiling, " + "skip cluster comparison.") + return job_list + + compare_profiling_list = [] + for dimension, compare_mode in zip(comparison_dims, comparison_modes): + step, benchmark_step, rank_id_for_comparison = AnalyzerController._get_step_rank_for_cluster_statistic_diff( + target_data, + benchmark_data, + headers, + dimension, + get_max=get_max + ) + + rank_profiling_path = self._get_profiling_path_by_rank(profiling_path, rank_id_for_comparison) + rank_benchmark_profiling_path = self._get_profiling_path_by_rank( + benchmark_profiling_path, + rank_id_for_comparison + ) + + if rank_id_for_comparison is None: + # rank id为空则无法获取对应rank的profiling路径,无法进行比较 + continue + + compare_profiling_list.append( + dict(profiling_path=rank_profiling_path, benchmark_profiling_path=rank_benchmark_profiling_path, + step=step, benchmark_step=benchmark_step, + rank=rank_id_for_comparison, benchmark_rank=rank_id_for_comparison, compare_mode=compare_mode) + ) + + if not compare_profiling_list: + return job_list + + job_list += self._profiling_comparison(compare_profiling_list) + return job_list + + def _is_cluster_profiling(self, profiling_path): + if os.path.isfile(profiling_path): + return False + path_list = [os.path.join(profiling_path, dir_name) for dir_name in os.listdir(profiling_path)] + ascend_pt_dirs = [path for path in path_list if os.path.isdir(path) and path.endswith("ascend_pt")] + data_processor = PytorchDataPreprocessor(ascend_pt_dirs) + + self.cluster_local_data_map[profiling_path] = data_processor.get_data_map() + + if not self.cluster_local_data_map or not self.cluster_local_data_map.get(profiling_path): + return False + + self.default_rank_id = list(self.cluster_local_data_map[profiling_path].keys())[0] + + return len(self.cluster_local_data_map[profiling_path]) >= self.CLUSTER_RANK_THRESHOLD + + def _get_profiling_path_by_rank(self, profiling_path, rank_id=None): + + if not profiling_path: + return profiling_path + + return self._get_target_profiling_path_for_local(profiling_path, rank_id) + + def _get_target_profiling_path_for_local(self, profiling_path, rank_id): + rank_id_map = self.cluster_local_data_map.get(profiling_path, {}) + if rank_id is None or not rank_id_map: + return profiling_path + + if rank_id in rank_id_map: + return rank_id_map.get(rank_id) + + local_first_rank_id = sorted(list(map(int, rank_id_map.keys())))[0] + logger.warning("Target rank id %s does not exist in local profiling data %s, use rank %s for analysis", + rank_id, profiling_path, local_first_rank_id) + return rank_id_map.get(local_first_rank_id) + + def _update_analysis_process_resp(self, pid, resp, **kwargs): + if kwargs: + resp.update(kwargs) + self.analysis_process_resp[pid] = resp + + def _get_analysis_finished_resp(self, pid, resp): + advisor_output_file_prefix = f"mstt_advisor_{Timer().strftime}" + html_path = os.path.join(Config().work_path, f"{advisor_output_file_prefix}.html") + xlsx_path = os.path.join(Config().work_path, "log", f"{advisor_output_file_prefix}.xlsx") + if os.path.exists(html_path) and os.path.exists(xlsx_path): + result_files = {"html": html_path, "xlsx": xlsx_path} + self._update_analysis_process_resp(pid, resp, status_code=AsyncAnalysisStatus.NON_FAILED_STATUS_CODE, + status=AsyncAnalysisStatus.SUCCESS, result_files=result_files) + else: + self._update_analysis_process_resp(pid, resp, status_code=AsyncAnalysisStatus.BAD_REQUEST_STATUS_CODE, + status=AsyncAnalysisStatus.FAILED, + error_msg="No optimization suggestions, please check your input path.") + + def _stage_computation_analysis(self, profiling_path, stage_step_rank, job_list): + # 对不同pp stage取min max进行分析 + logger.info("Steps and ranks to be analyzed of different pipeline parallel stages are %s", + json.dumps(stage_step_rank)) + + stages_profiling_path = [] + for stage, step_rank_info in stage_step_rank.items(): + rank_id = step_rank_info.get("maximum", {}).get("rank_id") + step = step_rank_info.get("maximum", {}).get("step") + benchmark_rank_id = step_rank_info.get("minimum", {}).get("rank_id") + benchmark_step = step_rank_info.get("minimum", {}).get("step") + + info_msg = f"For {stage}, slow rank is {rank_id}" + if step: + info_msg += f", step is {step}" + logger.info(info_msg) + + stages_profiling_path.append( + dict( + stage=stage, rank=rank_id, step=step, benchmark_rank=benchmark_rank_id, + benchmark_step=benchmark_step, + profiling_path=self._get_profiling_path_by_rank(profiling_path, rank_id), + benchmark_profiling_path=self._get_profiling_path_by_rank(profiling_path, benchmark_rank_id), + compare_mode=CompareConstant.KERNEL_COMPARE + ) + ) + Interface.add_analyzer(Interface.COMPUTATION, SupportedScopes.STAGE_COMPUTE, PPStageComputationAnalyzer) + compute_analysis_kwargs = {"stages_profiling_path": stages_profiling_path, "profiling_path": profiling_path} + + job_list.append((Interface.COMPUTATION, SupportedScopes.STAGE_COMPUTE, Interface(**compute_analysis_kwargs), + compute_analysis_kwargs)) + if self.kwargs.get("benchmark_profiling_path") is None: + logger.info("Enable computation comparison of fast and slow rank/step in different pp stages") + job_list += self._profiling_comparison(stages_profiling_path) + return job_list + + def _global_computation_analysis(self, profiling_path, global_step_rank, job_list): + # 不区分stage,对所有卡取Min max进行分析 + logger.info("Without pipeline parallel stage, steps and ranks to be analyzed are %s", + json.dumps(global_step_rank)) + slow_rank_id = global_step_rank.get("maximum", {}).get("rank_id") + if slow_rank_id is not None: + info_msg = f"Maximum computation time for rank {slow_rank_id}" + else: + slow_rank_id = self.default_rank_id + info_msg = f"No slow rank with computation time, analysis for default rank {slow_rank_id}" + slow_step = global_step_rank.get("maximum", {}).get("step") + # 如果没有标杆profiling数据的rank id,说明没有快慢卡问题,直接对默认rank id进行分析,因此这里取值为None + fast_rank_id = global_step_rank.get("minimum", {}).get("rank_id") + fast_step = global_step_rank.get("minimum", {}).get("step") + + if slow_step is not None: + info_msg += f" and step {slow_step}, " + if fast_rank_id is not None: + info_msg += f"minimum computation time for rank {fast_rank_id}" + if fast_step is not None: + info_msg += f" and step {fast_step}" + logger.info(info_msg) + + kwargs = dict(profiling_path=self._get_profiling_path_by_rank(profiling_path, slow_rank_id), + benchmark_profiling_path=self._get_profiling_path_by_rank(profiling_path, fast_rank_id), + step=slow_step, benchmark_step=fast_step, rank=slow_rank_id, benchmark_rank=fast_rank_id, + compare_mode=CompareConstant.KERNEL_COMPARE) + + job_list += self.computation_analysis(**kwargs) + + rank_id_valid = slow_rank_id is not None and fast_rank_id is not None and fast_rank_id != slow_rank_id + if self.kwargs.get("benchmark_profiling_path") is None and rank_id_valid: + # 当用户指定benchmark profiling path时,不进行目标集群profiling的内部快慢卡对比 + logger.info("Enable computation comparison of fast and slow rank/step") + job_list += self._profiling_comparison([kwargs]) + return job_list diff --git a/profiler/advisor/analyzer/base_analyzer.py b/profiler/advisor/analyzer/base_analyzer.py index 5f4bd3202cd2071088f25564a7d4b14144a34826..def95d8a25d83f1b5c3e4b76005b08306ba610bd 100644 --- a/profiler/advisor/analyzer/base_analyzer.py +++ b/profiler/advisor/analyzer/base_analyzer.py @@ -1,26 +1,45 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from functools import wraps from typing import Dict, List, Union from abc import abstractmethod, ABCMeta from profiler.advisor.common import constant +from profiler.advisor.common.enum_params_parser import EnumParamsParser from profiler.advisor.common.version_control import VersionControl from profiler.advisor.dataset.dataset import Dataset from profiler.advisor.result.result import OptimizeResult from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor +from profiler.advisor.utils.utils import safe_division logger = logging.getLogger() class BaseAnalyzer(VersionControl, metaclass=ABCMeta): - _SUPPORT_VERSIONS = constant.SUPPORTED_CANN_VERSION + _SUPPORT_VERSIONS = EnumParamsParser().get_options(constant.CANN_VERSION) + ANALYZER_HIGH_PRIORITY_TIME_RATIO = 0.05 + ANALYZER_MEDIUM_PRIORITY_TIME_RATIO = 0.03 dataset_cls_list = [] def __init__(self, collection_path, n_processes: int = 1, **kwargs): self.n_processes = n_processes - self.cann_version = kwargs.get("cann_version", constant.DEFAULT_CANN_VERSION) - self.torch_version = kwargs.get("torch_version", constant.DEFAULT_TORCH_VERSION) + self.cann_version = kwargs.get(constant.CANN_VERSION, EnumParamsParser().get_default(constant.CANN_VERSION)) + self.torch_version = kwargs.get(constant.TORCH_VERSION, EnumParamsParser().get_default(constant.TORCH_VERSION)) self.html_render = HTMLRender() self.collection_path = collection_path self.kwargs = kwargs @@ -29,6 +48,18 @@ class BaseAnalyzer(VersionControl, metaclass=ABCMeta): self.result = OptimizeResult() self.record_list: Dict[str, List] = {} + @staticmethod + def get_first_data_by_key(data, key) -> Union[Dataset, None]: + """ + get the first member from data with key + :param data: input data + :param key: data key + :return: the first dataset in dataset list + """ + if key in data and len(data[key]) > 0: + return data[key][0] + return None + @classmethod def check_data(cls, data_list: tuple): """ @@ -49,7 +80,7 @@ class BaseAnalyzer(VersionControl, metaclass=ABCMeta): return None logger.info("Enable analysis %s with %s", self.__class__.__name__, ",".join(data_list)) - return func(self) + return func(self, **kwargs) return wrapper @@ -60,19 +91,33 @@ class BaseAnalyzer(VersionControl, metaclass=ABCMeta): pass @abstractmethod - def make_record(self): - pass - - @abstractmethod - def make_render(self): + def get_priority(self, max_mem_op_dur): pass - def init_dataset_list(self)->None: + def init_dataset_list(self) -> None: dataset_cls_list = self.dataset_cls_list if len(dataset_cls_list) == 0: logger.warning(f"Analyser: %s don't rely on any dataset!", self.__class__.__name__) return + for dataset_cls in dataset_cls_list: + if dataset_cls and callable(dataset_cls): + try: + dataset = dataset_cls(collection_path=self.collection_path, data=self.dataset_list, **self.kwargs) + except Exception as e: + logger.error(e) + continue + key = dataset_cls.get_key() + if key not in self.dataset_list: + self.dataset_list[key] = [] + self.dataset_list[key].append(dataset) + + def init_dataset_list(self) -> None: + dataset_cls_list = self.dataset_cls_list + if len(dataset_cls_list) == 0: + logger.warning(f"Analyzer: %s don't rely on any dataset!", self.__class__.__name__) + return + for dataset_cls in dataset_cls_list: if dataset_cls and callable(dataset_cls): dataset = dataset_cls(collection_path=self.collection_path, data=self.dataset_list, **self.kwargs) @@ -81,14 +126,11 @@ class BaseAnalyzer(VersionControl, metaclass=ABCMeta): self.dataset_list[key] = [] self.dataset_list[key].append(dataset) - @staticmethod - def get_first_data_by_key(data, key) -> Union[Dataset, None]: - """ - get the first member from data with key - :param data: input data - :param key: data key - :return: the first dataset in dataset list - """ - if key in data and len(data[key]) > 0: - return data[key][0] - return None + def get_priority_by_time_ratio(self, dur, step_dur): + time_ratio = safe_division(dur, step_dur) + if time_ratio >= self.ANALYZER_HIGH_PRIORITY_TIME_RATIO: + return PriorityBackgroundColor.high + elif time_ratio >= self.ANALYZER_MEDIUM_PRIORITY_TIME_RATIO: + return PriorityBackgroundColor.medium + else: + return PriorityBackgroundColor.low diff --git a/profiler/advisor/analyzer/cluster/Communication_retransmission_analyzer.py b/profiler/advisor/analyzer/cluster/Communication_retransmission_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..3683ef1b44f8b6c571dd4d8fdce0d39882d342af --- /dev/null +++ b/profiler/advisor/analyzer/cluster/Communication_retransmission_analyzer.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.analyzer.cluster.Communication_retransmission_checker import CommunicationRetransmissionChecker +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.cluster.cluster_dataset import ClusterCommunicationDataset + +logger = logging.getLogger() + + +class RDMARetransmissionAnalyzer(BaseAnalyzer): + dataset_cls_list = [ClusterCommunicationDataset] + + def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: + super().__init__(collection_path, n_processes, **kwargs) + key = ClusterCommunicationDataset.get_key() + self.dataset = self.get_first_data_by_key(self.dataset_list, key) + self.result = OptimizeResult() + self.html_render = HTMLRender() + self.html = None + + @BaseAnalyzer.check_data((ClusterCommunicationDataset.get_key(),)) + def optimize(self, **kwargs): + add_render_list = kwargs.get("add_render_list", True) + rdma_checker = CommunicationRetransmissionChecker(**kwargs) + rdma_checker.check_retransmission(self.dataset) + if not rdma_checker.rdma_issues: + return self.result + rdma_checker.make_record(self.result) + self.html = rdma_checker.make_render(self.html_render, add_render_list) + return self.result diff --git a/profiler/advisor/analyzer/cluster/Communication_retransmission_checker.py b/profiler/advisor/analyzer/cluster/Communication_retransmission_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..c63fc12f27acfe1bf88c832b85e7746539143162 --- /dev/null +++ b/profiler/advisor/analyzer/cluster/Communication_retransmission_checker.py @@ -0,0 +1,136 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Dict, List +from collections import defaultdict +from profiler.advisor.dataset.cluster.cluster_dataset import ClusterCommunicationDataset +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.dataset.cluster.hccl_collection import HcclInfo + +logger = logging.getLogger() + + +class GroupStatistic: + def __init__(self, min_transmission_time): + self.retransmission_issue = False + self.abnormal_op_dict: Dict[str, List] = dict() + + def add_op(self, op_name: str, hccl_info: HcclInfo): + if self.abnormal_op_dict.get(op_name) is None: + self.abnormal_op_dict.setdefault(op_name, []) + self.abnormal_op_dict.get(op_name).append([hccl_info.group, op_name, hccl_info.step, hccl_info.rank, + hccl_info.get_rdma_transit_size(), + hccl_info.get_rdma_transmit_time(), hccl_info.get_rdma_bandwidth()]) + + +class CommunicationRetransmissionChecker: + def __init__(self, **kwargs): + self.rdma_issues = False + self.desc = "" + self.sdma_desc = "" + self.rdma_desc = "" + self.suggestions = [] + self.abnormal_group_count = 0 + self.abnormal_rdma_list = [] + self.step_id = kwargs.get("step") + self.stage = None + self.group_statistics = defaultdict(GroupStatistic) + self.headers = [ + "Communication group", + "Op name", + "Step id", + "Rank id", + "RDMA transmit size(MB)", + "RDMA transmit time(ms)", + "RDMA bandwidth", + ] + self._init_rule() + + def check_possible_retransmission_occurrence(self, hccl_list: List[HcclInfo]): + min_elapse_time = min(hccl.elapse_time for hccl in hccl_list) + max_transit_time = max(hccl.rdma_info.get('Transit Time(ms)', 0) for hccl in hccl_list) + if min_elapse_time < self.min_retransmission_time: # 检测是否是卡间不同步问题,而不是重传 + return False + return max_transit_time > self.min_retransmission_time + + def check_retransmission(self, hccl_dataset: ClusterCommunicationDataset): + """ + :Param event_dataset: dataset of timeline event + """ + for group_name, hccl_group_dict in hccl_dataset.hccl_dict.items(): + for op_name, hccl_op_dict in hccl_group_dict.items(): + for step_id, hccl_list in hccl_op_dict.items(): + if self.step_id and step_id != self.step_id: # 传输指定step(self.step_id)情况下,非目标step跳过 + continue + if not self.check_possible_retransmission_occurrence(hccl_list): + continue + self.rdma_issues = True + if self.group_statistics.get(group_name) is None: + self.group_statistics.setdefault(group_name, GroupStatistic(self.min_retransmission_time)) + self.abnormal_group_count += 1 + for hccl_info in hccl_list: + if hccl_info.rdma_info.get('Transit Size(MB)', 0): + transit_time = hccl_info.rdma_info.get('Transit Time(ms)', 0) + if transit_time > self.min_retransmission_time: + self.group_statistics.get(group_name).add_op(op_name, hccl_info) + if self.rdma_issues: + self.desc = self.desc.format(group_count=self.abnormal_group_count) + for _, group_statistic in self.group_statistics.items(): + for _, op_list in group_statistic.abnormal_op_dict.items(): + for op in op_list: + self.abnormal_rdma_list.append(op) + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + optimization_item = OptimizeItem("Communication retransmission analysis", self.desc, self.suggestions) + result.add(OptimizeRecord(optimization_item)) + + sub_table_name = \ + "Comm Retransmission Analysis" if not self.stage else f"Stage-{self.stage}: Comm Retransmission Analysis" + result.add_detail(sub_table_name, headers=self.headers) + + for row in self.abnormal_rdma_list: + result.add_detail(sub_table_name, detail=row) + + def make_render(self, html_render, add_render_list=True): + return html_render.render_template(key="cluster", + template_dir="templates", + template_name="communication_retransmission_analysis.html", + desc=self.desc, + solutions=self.solutions, + headers=self.headers, + data=self.abnormal_rdma_list + ) + + def _init_rule(self): + syncbn_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), + "rules", + "rdma_analysis.yaml" + ) + + syncbn_rule = FileManager.read_yaml_file(syncbn_rule_path) + self.desc = syncbn_rule.get("problem") + self.min_retransmission_time = syncbn_rule.get("min_retransmission_time") + + self.solutions = syncbn_rule.get("solutions") + for solution in self.solutions: + for key, val in solution.items(): + self.suggestions.append(f"{key}, {val.get('desc')}") diff --git a/profiler/advisor/analyzer/cluster/slow_link_analyser.py b/profiler/advisor/analyzer/cluster/slow_link_analyser.py index 846b79a50f31abb8445a0e5c2e82aaaf3c8ee23d..0b585cbc7c5f136b15cd9eb035ea2dac5caa9e4e 100644 --- a/profiler/advisor/analyzer/cluster/slow_link_analyser.py +++ b/profiler/advisor/analyzer/cluster/slow_link_analyser.py @@ -19,7 +19,7 @@ from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer from profiler.advisor.common import constant from profiler.advisor.result.result import OptimizeResult from profiler.advisor.result.item import OptimizeItem, OptimizeRecord -from profiler.advisor.dataset.cluster.cluster_dataset import ClusterCommunicationDataSet +from profiler.advisor.dataset.cluster.cluster_dataset import ClusterCommunicationDataset class SlowLinkAnalyzer(BaseAnalyzer): @@ -35,11 +35,11 @@ class SlowLinkAnalyzer(BaseAnalyzer): SDMA = "SDMA" RDMA = "RDMA" SLOW_LINK_ANALYSIS = "slow_link_analysis" - dataset_cls_list = [ClusterCommunicationDataSet] + dataset_cls_list = [ClusterCommunicationDataset] def __init__(self, collection_path, n_processes: int = 1, **kwargs): super().__init__(collection_path, n_processes, **kwargs) - key = ClusterCommunicationDataSet.get_key() + key = ClusterCommunicationDataset.get_key() self.communication_data_class = self.get_first_data_by_key(self.dataset_list, key) self.rank_bw_dict = self.communication_data_class.get_data() self.result = OptimizeResult() @@ -49,8 +49,9 @@ class SlowLinkAnalyzer(BaseAnalyzer): def optimize(self, **kwargs): if self.rank_bw_dict is None: - print("slow_link 分析失败,原因是数据加载失败,请检查你的cluster_analysis_outpu文件夹, \ - 如不关心这类数据请忽略") + print("Slow link analysis failed due to data loading failure. \ + Please check your cluster_analysis_output folder. \ + If you are not concerned about this type of data, please ignore this message.") return self.result self.process() self.format_datas = self.format_details() @@ -65,8 +66,11 @@ class SlowLinkAnalyzer(BaseAnalyzer): def produce_bottleneck(self, link_type: str): data_list = [rank_dict.get(link_type, 0) for rank_id, rank_dict in self.rank_bw_dict.items()] - avg_bw = round(sum(data_list) / len(data_list), 3) - if avg_bw == 0: + if len(data_list) > 0: + avg_bw = round(sum(data_list) / len(data_list), 3) + else: + print("The slow link (identified bottleneck) cannot provide a bottleneck \ + because the analysis data is missing bandwidth information.") return self.bottelneck += f'{link_type}: \n' \ f' The average is {avg_bw}, \n' \ diff --git a/profiler/advisor/analyzer/cluster/slow_link_analyzer.py b/profiler/advisor/analyzer/cluster/slow_link_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..259e5eb0c4255afc97aad83210b72a14b7285888 --- /dev/null +++ b/profiler/advisor/analyzer/cluster/slow_link_analyzer.py @@ -0,0 +1,195 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from typing import Dict, List +import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.common import constant +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.advisor.dataset.cluster.cluster_dataset import ClusterCommunicationDataset +from profiler.advisor.utils.utils import safe_index_value, convert_to_int + +logger = logging.getLogger() + + +class SlowLinkAnalyzer(BaseAnalyzer): + RDMA_TIME_MS = "RDMA time(ms)" + RDMA_SIZE_MB = "RDMA size(mb)" + SDMA_TIME_MS = "SDMA time(ms)" + SDMA_SIZE_MB = "SDMA size(mb)" + RDMA_BANDWIDTH = "RDMA bandwidth(GB/s)" + SDMA_BANDWIDTH = "SDMA bandwidth(GB/s)" + COMMUNICATION_BANDWIDTH_INFO = "Communication Bandwidth Info" + TRANSIT_TIME = "Transit Time(ms)" + TRANSIT_SIZE = "Transit Size(MB)" + SDMA = "SDMA" + RDMA = "RDMA" + SLOW_LINK_ANALYSIS = "slow link" + RATIO_THRESHOLD = 0.05 + dataset_cls_list = [ClusterCommunicationDataset] + + def __init__(self, collection_path, n_processes: int = 1, **kwargs): + super().__init__(collection_path, n_processes, **kwargs) + key = ClusterCommunicationDataset.get_key() + self.communication_data_class = self.get_first_data_by_key(self.dataset_list, key) + self.rank_bw_dict = self.communication_data_class.get_data() + self.result = OptimizeResult() + self.bottelneck = '' + self.suggestion = '' + self.format_datas = {} + if self.rank_bw_dict is not None: + self.format_datas = self.format_details() + + @staticmethod + def compute_max_gap_ratio(data: list, mean: float): + if mean == 0: + return 0 + else: + return (max(data) - min(data)) / mean + + def optimize(self, **kwargs): + if self.rank_bw_dict is None: + logger.error("Slow link analysis failed due to data loading failure. \ + Please check your cluster_analysis_output folder. \ + If you are not concerned about this type of data, please ignore this message.") + return self.result + self.process() + self.make_record() + self.make_render(kwargs.get("template_key")) + return self.result + + def process(self): + if self.rank_bw_dict: + self.produce_bottleneck(self.RDMA_BANDWIDTH) + self.produce_bottleneck(self.SDMA_BANDWIDTH) + + def produce_bottleneck(self, link_type: str): + data_list = [rank_dict.get(link_type, 0) for rank_id, rank_dict in self.rank_bw_dict.items()] + if len(data_list) > 0: + avg_bw = round(sum(data_list) / len(data_list), 3) + else: + logger.info("The slow link (identified bottleneck) cannot provide a bottleneck \ + because the analysis data is missing bandwidth information.") + return + self.bottelneck += f'{link_type}: \n' \ + f' The average is {avg_bw}, \n' \ + f' while the maximum is {round(max(data_list), 3)}GB/s \n' \ + f' and the minimum is {round(min(data_list), 3)}GB/s. \n' \ + f' the difference is {round(max(data_list) - min(data_list), 3)}GB/s. \n' + + def format_details(self): + if not self.rank_bw_dict: + return { + "headers": [], + "data": [] + } + + details_dict = {} + headers = list({k for rank_bw_value in self.rank_bw_dict.values() for k in rank_bw_value.keys()}) + headers.sort() + + data_list = [] + for step_rank, rank_bw in self.rank_bw_dict.items(): + step_rank_list = list(map(convert_to_int, step_rank.split(constant.STEP_RANK_SEP))) + value_list = [rank_bw.get(i, 0) for i in headers] + data_list.append(step_rank_list + value_list) + data_list.sort(key=lambda x: (x[0], x[1])) # 按rank_id排序 + + details_dict["headers"] = ["step", "rank_id"] + headers + details_dict["data"] = data_list + + return details_dict + + def make_record(self): + """ + make record for what and how to optimize + """ + optimization_item = OptimizeItem( + SlowLinkAnalyzer.SLOW_LINK_ANALYSIS, + self.bottelneck, + self.suggestion + ) + self.result.add(OptimizeRecord(optimization_item)) + + data_list = self.format_datas.get("data", []) + headers = self.format_datas.get("headers", []) + for data in data_list: + self.result.add_detail(SlowLinkAnalyzer.SLOW_LINK_ANALYSIS, headers, data) + + def make_render(self, template_key="cluster"): + result_for_html = { + "Description": self.bottelneck, + "suggestion": self.suggestion, + "details": [self.format_datas] + } + + self.html_render.render_template(key=template_key, + title=SlowLinkAnalyzer.SLOW_LINK_ANALYSIS, + template_dir="templates", + template_name="cluster_analysis.html", + cann_version=self.cann_version, + torch_version=self.torch_version, + result=result_for_html) + + def get_global_step_rank(self, bindwidth_type): + global_step_rank = {} + if not self.format_datas: + return global_step_rank + + bindwidth_key_map = {self.RDMA: self.RDMA_BANDWIDTH, self.SDMA: self.SDMA_BANDWIDTH} + + if bindwidth_type not in bindwidth_key_map: + raise RuntimeError(f"Error bindwidth type {bindwidth_type}, optionals are {bindwidth_key_map.keys()}") + + headers = self.format_datas.get("headers") + + bindwidth_index = safe_index_value(headers, bindwidth_key_map.get(bindwidth_type)) + + if bindwidth_index is not None: + data_list = [tuple_list[bindwidth_index] for tuple_list in self.format_datas.get("data", [])] + max_bandwidth, min_bandwidth = max(data_list), min(data_list) + + if self.compute_max_gap_ratio(data_list, sum(data_list) / len( + data_list)) < self.RATIO_THRESHOLD: + return global_step_rank + + max_bandwidth_index = data_list.index(max_bandwidth) + min_bandwidth_index = data_list.index(min_bandwidth) + + rank_id_index = safe_index_value(headers, "rank_id") + step_index = safe_index_value(headers, "step") + + if rank_id_index is None: + return global_step_rank + + max_bandwidth_rank_id = self.format_datas.get("data")[max_bandwidth_index][rank_id_index] + min_bandwidth_rank_id = self.format_datas.get("data")[min_bandwidth_index][rank_id_index] + + if step_index is None: + max_bandwidth_step, min_bandwidth_step = constant.DEFAULT_STEP, constant.DEFAULT_STEP + else: + max_bandwidth_step = self.format_datas.get("data")[max_bandwidth_index][step_index] + min_bandwidth_step = self.format_datas.get("data")[min_bandwidth_index][step_index] + + global_step_rank["maximum"] = {"rank_id": max_bandwidth_rank_id, "step": max_bandwidth_step} + global_step_rank["minimum"] = {"rank_id": min_bandwidth_rank_id, "step": min_bandwidth_step} + + return global_step_rank + + def get_priority(self): + pass diff --git a/profiler/advisor/analyzer/cluster/slow_rank_analyser.py b/profiler/advisor/analyzer/cluster/slow_rank_analyser.py index aa0ddad5078252d61bf92b2be10f33dc56f85ab4..f439b31f7736ee4777d5ef10bf968738a76ae1b3 100644 --- a/profiler/advisor/analyzer/cluster/slow_rank_analyser.py +++ b/profiler/advisor/analyzer/cluster/slow_rank_analyser.py @@ -19,7 +19,7 @@ from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer from profiler.advisor.common import constant from profiler.advisor.result.result import OptimizeResult from profiler.advisor.result.item import OptimizeItem, OptimizeRecord -from profiler.advisor.dataset.cluster.cluster_dataset import ClusterStepTraceTimeDataSet +from profiler.advisor.dataset.cluster.cluster_dataset import ClusterStepTraceTimeDataset class SlowRankAnalyzer(BaseAnalyzer): @@ -27,11 +27,11 @@ class SlowRankAnalyzer(BaseAnalyzer): RANK = "rank" RATIO_THRESHOLD = 0.05 BOTTLENECK_LIST = ['Computing', 'Communication', "Free"] - dataset_cls_list = [ClusterStepTraceTimeDataSet] + dataset_cls_list = [ClusterStepTraceTimeDataset] def __init__(self, collection_path, n_processes: int = 1, **kwargs): super().__init__(collection_path, n_processes, **kwargs) - key = ClusterStepTraceTimeDataSet.get_key() + key = ClusterStepTraceTimeDataset.get_key() self.step_trace_class = self.get_first_data_by_key(self.dataset_list, key) self.step_trace_dict = self.step_trace_class.get_data() self.result = OptimizeResult() diff --git a/profiler/advisor/analyzer/cluster/slow_rank_analyzer.py b/profiler/advisor/analyzer/cluster/slow_rank_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..1cff3b87f157d8a4cff795b4fc1035581dce5caa --- /dev/null +++ b/profiler/advisor/analyzer/cluster/slow_rank_analyzer.py @@ -0,0 +1,228 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.common import constant +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.advisor.dataset.cluster.cluster_dataset import ClusterStepTraceTimeDataset +from profiler.advisor.utils.utils import safe_index_value, safe_division, convert_to_int + +logger = logging.getLogger() + + +class SlowRankAnalyzer(BaseAnalyzer): + SLOW_RANK_ANALYSIS = "slow rank" + RANK = "rank" + RATIO_THRESHOLD = 0.05 + BOTTLENECK_LIST = ['Computing', 'Communication', "Free"] + dataset_cls_list = [ClusterStepTraceTimeDataset] + COMPUTE = "compute(us)" + FREE = "free(us)" + COMMUNICATION = "communication(us)" + + def __init__(self, collection_path, n_processes: int = 1, **kwargs): + super().__init__(collection_path, n_processes, **kwargs) + key = ClusterStepTraceTimeDataset.get_key() + self.step_trace_class = self.get_first_data_by_key(self.dataset_list, key) + self.step_trace_dict = self.step_trace_class.get_data() + self.stages = self.step_trace_class.get_stages() + self.result = OptimizeResult() + self.bottelneck = '' + self.suggestion = '' + self._steps = set() + self.format_datas = {} + if self.step_trace_dict is not None: + self.format_datas = self.format_details() + + @property + def steps(self): + return sorted(list(self._steps)) + + @staticmethod + def compute_max_gap_ratio(data: list, mean: float): + if mean == 0: + return 0 + else: + return (max(data) - min(data)) / mean + + def optimize(self, **kwargs): + if self.step_trace_dict is None: + logger.error( + "Slow rank analysis failed, " + "please ensure file 'step_trace_time.csv' exists in your profiling directory %s", + constant.ASCEND_PROFILER_OUTPUT) + return self.result + self.process() + self.make_record() + self.make_render(kwargs.get("template_key")) + return self.result + + def process(self): + total_time_list = [sum(data_tuple) for rank_id, data_tuple in self.step_trace_dict.items()] + if total_time_list: + mean_total_time = sum(total_time_list) / len(total_time_list) + for i in range(len(self.BOTTLENECK_LIST)): + self.produce_bottleneck(self.step_trace_dict, i, mean_total_time) + + if not self.bottelneck: + self.bottelneck = "There is no slow rank issues" + + def produce_bottleneck(self, step_dict: dict, produce_type: int, mean_total_time: float): + data_list = [data_tuple[produce_type] for rank_id, data_tuple in step_dict.items()] + max_ratio = self.compute_max_gap_ratio(data_list, mean_total_time) + if max_ratio > self.RATIO_THRESHOLD: + self.bottelneck += f'{self.BOTTLENECK_LIST[produce_type]} \n' \ + f' has some issues in the cluster, \n' \ + f' because the max difference of {self.BOTTLENECK_LIST[produce_type]} time \n' \ + f' has reached {round(max_ratio * mean_total_time / 1000, 3)}ms. \n' + + def make_record(self): + """ + make record for what and how to optimize + """ + + optimization_item = OptimizeItem( + SlowRankAnalyzer.SLOW_RANK_ANALYSIS, + self.bottelneck, + self.suggestion + ) + self.result.add(OptimizeRecord(optimization_item)) + + data_list = self.format_datas.get("data", []) + headers = self.format_datas.get("headers", []) + for data in data_list: + self.result.add_detail(SlowRankAnalyzer.SLOW_RANK_ANALYSIS, headers, data) + + def format_details(self): + details_dict = {} + headers = ["step", "rank_id", "compute(us)", "communication(us)", "free(us)"] + data_list = [] + for key, value in self.step_trace_dict.items(): + step, rank_id = key.split(constant.STEP_RANK_SEP) + data_list.append([convert_to_int(step), convert_to_int(rank_id)] + value) + if step and step not in self._steps: + self._steps.add(step) + + details_dict["headers"] = headers + details_dict["data"] = sorted(data_list, key=lambda x: (x[0], x[1])) + return details_dict + + def make_render(self, template_key="cluster"): + result_for_html = { + "Description": self.bottelneck, + "suggestion": self.suggestion, + "details": [self.format_datas] + } + + self.html_render.render_template(key=template_key, + title=SlowRankAnalyzer.SLOW_RANK_ANALYSIS, + template_dir="templates", + template_name="cluster_analysis.html", + cann_version=self.cann_version, + torch_version=self.torch_version, + result=result_for_html) + + def get_global_step_rank(self, dimension): + global_step_rank = {} + if not self.format_datas: + return global_step_rank + + headers = self.format_datas.get("headers") + + dimension_index = safe_index_value(headers, dimension) + rank_id_index = safe_index_value(headers, "rank_id") + step_index = safe_index_value(headers, "step") + if dimension_index is None or rank_id_index is None: + return global_step_rank + + data_list = [tuple_list[dimension_index] for tuple_list in self.format_datas.get("data")] + max_time, min_time = max(data_list), min(data_list) + + if self.compute_max_gap_ratio(data_list, sum(data_list) / len( + data_list)) < self.RATIO_THRESHOLD: + logger.info("There is no significant difference in computation time among all ranks") + return global_step_rank + max_time_index = data_list.index(max_time) + min_time_index = data_list.index(min_time) + + max_time_rank_id = self.format_datas.get("data")[max_time_index][rank_id_index] + min_time_rank_id = self.format_datas.get("data")[min_time_index][rank_id_index] + + if step_index is not None: + max_time_step = self.format_datas.get("data")[max_time_index][step_index] + min_time_step = self.format_datas.get("data")[min_time_index][step_index] + else: + max_time_step, min_time_step = constant.DEFAULT_STEP, constant.DEFAULT_STEP + + global_step_rank["maximum"] = {"rank_id": max_time_rank_id, "step": max_time_step} + global_step_rank["minimum"] = {"rank_id": min_time_rank_id, "step": min_time_step} + + return global_step_rank + + def get_stage_step_rank(self, dimension): + stage_step_rank = {} + if not self.format_datas: + return stage_step_rank + + headers = self.format_datas.get("headers") + dimension_index = safe_index_value(headers, dimension) + rank_id_index = safe_index_value(headers, "rank_id") + step_index = safe_index_value(headers, "step") + if dimension_index is None or rank_id_index is None: + return stage_step_rank + + rank_list = [tuple_list[rank_id_index] for tuple_list in self.format_datas.get("data")] + cost_time_list = [tuple_list[dimension_index] for tuple_list in self.format_datas.get("data")] + + if step_index is not None: + step_list = [tuple_list[step_index] for tuple_list in self.format_datas.get("data")] + else: + step_list = [constant.DEFAULT_STEP] * len(rank_list) + + for index, stage in enumerate(self.stages): + tmp_step_list, tmp_rank_list, tmp_time_list = [], [], [] + for step, rank_id, time in zip(step_list, rank_list, cost_time_list): + if rank_id not in stage: + continue + + tmp_step_list.append(step) + tmp_rank_list.append(rank_id) + tmp_time_list.append(time) + + if self.compute_max_gap_ratio(tmp_time_list, safe_division(sum(tmp_time_list), len( + tmp_time_list))) < self.RATIO_THRESHOLD: + continue + + max_time, min_time = max(tmp_time_list), min(tmp_time_list) + max_time_index, min_time_index = tmp_time_list.index(max_time), tmp_time_list.index(min_time) + + stage_key = f"stage-{index}" + stage_step_rank[stage_key] = {} + stage_step_rank[stage_key]["maximum"] = { + "rank_id": tmp_rank_list[max_time_index], + "step": tmp_step_list[max_time_index], + } + stage_step_rank[stage_key]["minimum"] = { + "rank_id": tmp_rank_list[min_time_index], + "step": tmp_step_list[min_time_index], + } + + return stage_step_rank + + def get_priority(self): + pass diff --git a/profiler/advisor/analyzer/communication/base_communication_analyzer.py b/profiler/advisor/analyzer/communication/base_communication_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc515e7a6fa9990046f1061aeae39e43e730b8a --- /dev/null +++ b/profiler/advisor/analyzer/communication/base_communication_analyzer.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer + + +class BaseCommunicationAnalyzer(BaseAnalyzer): + requires_cluster_dataset = True + + def __init__(self, collection_path, n_processes: int = 1, **kwargs): + super().__init__(collection_path, n_processes, **kwargs) diff --git a/profiler/advisor/analyzer/communication/contention/__init__.py b/profiler/advisor/analyzer/communication/contention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/analyzer/communication/contention/bandwidth_contention_analyzer.py b/profiler/advisor/analyzer/communication/contention/bandwidth_contention_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..46c8be7a482b38fb32234eca179b8cf89f992f32 --- /dev/null +++ b/profiler/advisor/analyzer/communication/contention/bandwidth_contention_analyzer.py @@ -0,0 +1,56 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.communication.base_communication_analyzer import BaseCommunicationAnalyzer +from profiler.advisor.analyzer.communication.contention.bandwidth_contention_checker import BandwidthContentionChecker +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.communication.communication_dataset import CommunicationDataset +from profiler.advisor.dataset.profiling.profiling_dataset import ProfilingDataset +from profiler.advisor.result.result import OptimizeResult + +logger = logging.getLogger() + + +class BandwidthContentionAnalyzer(BaseCommunicationAnalyzer): + dataset_cls_list = [ProfilingDataset, CommunicationDataset] + requires_cluster_dataset = False + + def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: + super().__init__(collection_path, n_processes, **kwargs) + communication_key = CommunicationDataset.get_key() + profiling_key = ProfilingDataset.get_key() + self.communication_dataset = self.get_first_data_by_key(self.dataset_list, communication_key) + self.profiling_dataset = self.get_first_data_by_key(self.dataset_list, profiling_key) + self.result = OptimizeResult() + self.html_render = HTMLRender() + self.html = None + + @BaseCommunicationAnalyzer.check_data((CommunicationDataset.get_key(),)) + def optimize(self, **kwargs): + add_render_list = kwargs.get("add_render_list", True) + bandwidth_contention_checker = BandwidthContentionChecker(**kwargs) + bandwidth_contention_checker.check_contention(self.communication_dataset, self.profiling_dataset) + if not bandwidth_contention_checker.contention_issues: + return self.result + bandwidth_contention_checker.make_record(self.result) + self.html = bandwidth_contention_checker.make_render(self.html_render, add_render_list, + priority=self.get_priority()) + return self.result + + def get_priority(self): + # 提升1% ~ 3% + return PriorityBackgroundColor.low diff --git a/profiler/advisor/analyzer/communication/contention/bandwidth_contention_checker.py b/profiler/advisor/analyzer/communication/contention/bandwidth_contention_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b85592d995024ef392ccd1b44eeb3fa1ab00eb --- /dev/null +++ b/profiler/advisor/analyzer/communication/contention/bandwidth_contention_checker.py @@ -0,0 +1,175 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import List +from profiler.advisor.dataset.communication.communication_dataset import CommunicationDataset +from profiler.advisor.dataset.profiling.profiling_dataset import ProfilingDataset +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.utils.utils import convert_to_float +from profiler.advisor.dataset.cluster.hccl_collection import HcclInfo +from profiler.advisor.dataset.profiling.info_collection import OpInfo + +logger = logging.getLogger() + + +class SDMAOperator: + def __init__(self, hccl_info: HcclInfo): + self._ts = hccl_info.ts + self._dur = hccl_info.elapse_time + self._name = hccl_info.name + self._bandwidth = hccl_info.sdma_info.get('Bandwidth(GB/s)', 0) + + @property + def ts(self): + return self._ts + + @property + def dur(self): + return self._dur + + @property + def end(self): + return self._ts + self._dur * 1000 + + @property + def name(self): + return self._name + + @property + def bandwidth(self): + return self._bandwidth + + +class BandwidthContentionChecker: + _CHECKER = "BandwidthContentionChecker" + + def __init__(self, **kwargs): + self.contention_issues = False + self.desc = "" + self.step_id = kwargs.get("step") + self.stage = None + self.threshold = 0 + self.contention_topk = 0 + self.sdma_list: List[SDMAOperator] = [] + self.matmul_list: List[OpInfo] = [] + self.abnormal_sdma_list: List[SDMAOperator] = [] + self.suggestions = [] + self._init_rule() + self.headers = ["op name", "duration(ms)", "bandwidth(GB/s)"] + + @staticmethod + def check_sdma_operator(hccl_op: HcclInfo): + if hccl_op.sdma_info: + if hccl_op.sdma_info.get('Transit Size(MB)', 0): + return True + return False + + def export_sdma_list(self): + res = [] + for hccl_op in self.abnormal_sdma_list: + res.append([hccl_op.name, round(hccl_op.dur, 4), round(hccl_op.bandwidth, 2)]) + res.sort(key=lambda x: x[2]) + return res[:min(len(res), self.contention_topk)] + + def check_task_dict(self, profiling_dataset: ProfilingDataset) -> bool: + if not hasattr(profiling_dataset, "op_summary"): + logger.warning("Skip %s checker because of not containing %s", self._CHECKER, "op summary") + return False + if not hasattr(profiling_dataset.op_summary, "task_dict"): + logger.warning("Skip %s checker because of not containing %s", self._CHECKER, "op summary") + return False + return True + + def extract_matmul_operator(self, profiling_dataset: ProfilingDataset): + for key, value in profiling_dataset.op_summary.task_dict.items(): + if "matmul" in key.lower(): + self.matmul_list.extend(value) + self.matmul_list.sort(key=lambda x: convert_to_float(x.task_start_time)) + + def extract_sdma_operator(self, hccl_dataset: CommunicationDataset): + for step_id, step_data in hccl_dataset.hccl_dict.items(): + if self.step_id is not None and step_id != self.step_id: + continue + for hccl_op in step_data: + if self.check_sdma_operator(hccl_op): + self.sdma_list.append(SDMAOperator(hccl_op)) + self.sdma_list.sort(key=lambda x: x.ts) + + def check_contention(self, hccl_dataset: CommunicationDataset, profiling_dataset: ProfilingDataset) -> None: + if not self.check_task_dict(profiling_dataset): + return + self.extract_matmul_operator(profiling_dataset) + self.extract_sdma_operator(hccl_dataset) + hccl_index = 0 + matmul_index = 0 + while hccl_index < len(self.sdma_list) and matmul_index < len(self.matmul_list): + if self.sdma_list[hccl_index].end < self.matmul_list[matmul_index].get_float_attr("task_start_time"): + hccl_index += 1 + elif self.matmul_list[matmul_index].get_float_attr("task_start_time") + \ + self.matmul_list[matmul_index].get_float_attr("task_duration") < self.sdma_list[hccl_index].ts: + matmul_index += 1 + else: + if self.sdma_list[hccl_index].bandwidth < self.threshold: + self.abnormal_sdma_list.append(self.sdma_list[hccl_index]) + matmul_index += 1 + if self.abnormal_sdma_list: + self.contention_issues = True + self.desc = self.desc.format(threshold=self.threshold) + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + optimization_item = OptimizeItem("bandwidth contention analysis", self.desc, self.suggestions) + result.add(OptimizeRecord(optimization_item)) + + sub_table_name = "Bandwidth Contention Analysis" if not self.stage else f"Stage-{self.stage}: " \ + f"Bandwidth Contention Analysis" + result.add_detail(sub_table_name, headers=self.headers) + for hccl_op in self.abnormal_sdma_list: + result.add_detail(sub_table_name, detail=[hccl_op.name, round(hccl_op.dur, 4), round(hccl_op.bandwidth, 2)]) + + def make_render(self, html_render, add_render_list=True, **kwargs): + priority = kwargs.get("priority") + return html_render.render_template(key="communication", + template_dir="templates", + template_name="contention.html", + desc=self.desc, + solutions=self.solutions, + headers=self.headers, + data=self.export_sdma_list(), + topk=self.contention_topk, + priority_background_color=priority) + + def _init_rule(self): + contention_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))), + "rules", + "bandwidth_contention.yaml" + ) + + contention_rule = FileManager.read_yaml_file(contention_rule_path) + self.desc = contention_rule.get("problem") + self.threshold = contention_rule.get("threshold", 0) * contention_rule.get("sdma_baseline", 0) + self.contention_topk = contention_rule.get("top_num", 3) + self.solutions = contention_rule.get("solutions") + if not self.desc or not self.solutions or not isinstance(self.solutions, list): + raise RuntimeError("The configuration file of the bandwidth contention analyzer is abnormal. Please check.") + for solution in self.solutions: + for key, val in solution.items(): + self.suggestions.append(f"{key}, {val.get('desc')}") diff --git a/profiler/advisor/analyzer/communication/packet/__init__.py b/profiler/advisor/analyzer/communication/packet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/analyzer/communication/packet/packet_analyzer.py b/profiler/advisor/analyzer/communication/packet/packet_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..444b643a1c6914b36476145caa866be81fcf65a4 --- /dev/null +++ b/profiler/advisor/analyzer/communication/packet/packet_analyzer.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.communication.base_communication_analyzer import BaseCommunicationAnalyzer +from profiler.advisor.analyzer.communication.packet.packet_checker import PacketChecker +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.communication.communication_dataset import CommunicationDataset +from profiler.advisor.result.result import OptimizeResult + +logger = logging.getLogger() + + +class PacketAnalyzer(BaseCommunicationAnalyzer): + dataset_cls_list = [CommunicationDataset] + requires_cluster_dataset = False + + def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: + super().__init__(collection_path, n_processes, **kwargs) + key = CommunicationDataset.get_key() + self.dataset = self.get_first_data_by_key(self.dataset_list, key) + self.result = OptimizeResult() + self.html_render = HTMLRender() + self.html = None + + @BaseCommunicationAnalyzer.check_data((CommunicationDataset.get_key(),)) + def optimize(self, **kwargs): + add_render_list = kwargs.get("add_render_list", True) + if not hasattr(self.dataset, "hccl_dict"): + return self.result + packet_checker = PacketChecker(**kwargs) + packet_checker.check_packet(self.dataset) + if not packet_checker.packet_issues: + return self.result + packet_checker.make_record(self.result) + self.html = packet_checker.make_render(self.html_render, add_render_list, priority=self.get_priority()) + return self.result + + def get_priority(self): + # 提升1% ~ 3% + return PriorityBackgroundColor.low diff --git a/profiler/advisor/analyzer/communication/packet/packet_checker.py b/profiler/advisor/analyzer/communication/packet/packet_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..6ddf17c43fdddc7adfd98866bc8869206c4cf942 --- /dev/null +++ b/profiler/advisor/analyzer/communication/packet/packet_checker.py @@ -0,0 +1,149 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from profiler.advisor.dataset.communication.communication_dataset import CommunicationDataset +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.utils.utils import convert_to_float + +logger = logging.getLogger() + + +class Statistic: + def __init__(self, min_ratio, min_size, desc, type_): + self.issue = False + self.count = 0 + self.abnormal_count = 0 + self.abnormal_duration = 0 + self.abnormal_ratio = 0 + self.min_ratio = min_ratio + self.min_size = min_size + self.desc = desc + self.type = type_ + + def check_threshold(self): + if self.count and self.abnormal_count: + self.abnormal_ratio = self.abnormal_count / self.count + if self.abnormal_ratio > self.min_ratio: + self.issue = True + return self.issue + + def process(self, hccl_info): + info = dict() + if self.type == "SDMA": + info = hccl_info.sdma_info + elif self.type == "RDMA": + info = hccl_info.rdma_info + if info.get('Transit Size(MB)', 0): + packet_size = info.get('Transit Size(MB)', 0) + if packet_size < self.min_size: + self.abnormal_count += 1 + self.abnormal_duration += info.get('Transit Time(ms)', 0) + self.count += 1 + + def adapt(self, dst_headers: list, src_headers, datas: list): + if not self.issue: + return False + dst_headers.extend(src_headers) + datas.extend([self.count, self.abnormal_count, self.abnormal_ratio, self.abnormal_duration]) + self.desc = self.desc.format( + abnormal_ratio=f"{round(self.abnormal_ratio, 4):.2%}", + min_size=self.min_size, + abnormal_time=round(self.abnormal_duration, 4)) + return True + + +class PacketChecker: + def __init__(self, **kwargs): + self.packet_issues = False + self.desc = "" + self.sdma_desc = "" + self.rdma_desc = "" + self.suggestions = [] + self.min_sdma_size = 0 + self.min_rdma_size = 0 + self.min_sdma_ratio = 0 + self.min_rdma_ratio = 0 + self.step_id = kwargs.get("step") + self.stage = None + self.packet_issues = False + self._init_rule() + self.sdma_statistic = Statistic(self.min_sdma_ratio, self.min_sdma_size, self.sdma_desc, "SDMA") + self.rdma_statistic = Statistic(self.min_rdma_ratio, self.min_rdma_size, self.rdma_desc, "RDMA") + self.small_packet_detail = [] + self.headers = [] + self.sdma_headers = ["SDMA total count", "Small SDMA count", "Small SDMA ratio", "Small SDMA duration(ms)"] + self.rdma_headers = ["RDMA total count", "Small RDMA count", "Small RDMA ratio", "Small RDMA duration(ms)"] + + def check_packet(self, hccl_dataset: CommunicationDataset): + for step_id, hccl_list in hccl_dataset.hccl_dict.items(): + if self.step_id and step_id != self.step_id: + continue + for hccl_info in hccl_list: + self.sdma_statistic.process(hccl_info) + self.rdma_statistic.process(hccl_info) + self.sdma_statistic.check_threshold() + self.rdma_statistic.check_threshold() + if self.sdma_statistic.adapt(self.headers, self.sdma_headers, self.small_packet_detail): + self.packet_issues = True + self.desc += self.sdma_statistic.desc + if self.rdma_statistic.adapt(self.headers, self.rdma_headers, self.small_packet_detail): + self.packet_issues = True + self.desc += self.rdma_statistic.desc + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + optimization_item = OptimizeItem("Packet analysis", self.desc, self.suggestions) + result.add(OptimizeRecord(optimization_item)) + + sub_table_name = "Packet Analysis" if not self.stage else f"Stage-{self.stage}: Packet Analysis" + result.add_detail(sub_table_name, headers=self.headers) + result.add_detail(sub_table_name, detail=self.small_packet_detail) + + def make_render(self, html_render, add_render_list=True, **kwargs): + priority = kwargs.get("priority") + return html_render.render_template(key="communication", + template_dir="templates", + template_name="packet_analysis.html", + desc=self.desc, + solutions=self.solutions, + headers=self.headers, + data=self.small_packet_detail, + priority_background_color=priority) + + def _init_rule(self): + syncbn_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))), + "rules", + "packet.yaml" + ) + + syncbn_rule = FileManager.read_yaml_file(syncbn_rule_path) + self.desc = syncbn_rule.get("problem") + self.sdma_desc = syncbn_rule.get("sdma_problem") + self.rdma_desc = syncbn_rule.get("rdma_problem") + self.min_sdma_size = convert_to_float(syncbn_rule.get("min_sdma_size")) + self.min_rdma_size = convert_to_float(syncbn_rule.get("min_rdma_size")) + self.min_sdma_ratio = convert_to_float(syncbn_rule.get("min_sdma_ratio")) + self.min_rdma_ratio = convert_to_float(syncbn_rule.get("min_rdma_ratio")) + + self.solutions = syncbn_rule.get("solutions") + for solution in self.solutions: + for key, val in solution.items(): + self.suggestions.append(f"{key}, {val.get('desc')}") diff --git a/profiler/advisor/analyzer/communication/retransmission/__init__.py b/profiler/advisor/analyzer/communication/retransmission/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/analyzer/communication/retransmission/communication_retransmission_analyzer.py b/profiler/advisor/analyzer/communication/retransmission/communication_retransmission_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..78cade900731926f6be2303fd8b9ac6072df35f7 --- /dev/null +++ b/profiler/advisor/analyzer/communication/retransmission/communication_retransmission_analyzer.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.communication.base_communication_analyzer import BaseCommunicationAnalyzer +from profiler.advisor.analyzer.communication.retransmission.communication_retransmission_checker import \ + CommunicationRetransmissionChecker +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.cluster.cluster_dataset import ClusterCommunicationDataset +from profiler.advisor.result.result import OptimizeResult + +logger = logging.getLogger() + + +class RDMARetransmissionAnalyzer(BaseCommunicationAnalyzer): + dataset_cls_list = [ClusterCommunicationDataset] + + def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: + super().__init__(collection_path, n_processes, **kwargs) + key = ClusterCommunicationDataset.get_key() + self.dataset = self.get_first_data_by_key(self.dataset_list, key) + self.result = OptimizeResult() + self.html_render = HTMLRender() + self.html = None + + @BaseCommunicationAnalyzer.check_data((ClusterCommunicationDataset.get_key(),)) + def optimize(self, **kwargs): + add_render_list = kwargs.get("add_render_list", True) + rdma_checker = CommunicationRetransmissionChecker(**kwargs) + rdma_checker.check_retransmission(self.dataset) + if not rdma_checker.rdma_issues: + return self.result + rdma_checker.make_record(self.result) + self.html = rdma_checker.make_render(self.html_render, add_render_list, priority=self.get_priority()) + return self.result + + def get_priority(self): + # 单次重传最少4s,高优先级 + return PriorityBackgroundColor.high diff --git a/profiler/advisor/analyzer/communication/retransmission/communication_retransmission_checker.py b/profiler/advisor/analyzer/communication/retransmission/communication_retransmission_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..577f7c23ff21dccaad138acd95ca9af20faa7bab --- /dev/null +++ b/profiler/advisor/analyzer/communication/retransmission/communication_retransmission_checker.py @@ -0,0 +1,137 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Dict, List +from collections import defaultdict +from profiler.advisor.dataset.cluster.cluster_dataset import ClusterCommunicationDataset +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.dataset.cluster.hccl_collection import HcclInfo + +logger = logging.getLogger() + + +class GroupStatistic: + def __init__(self, min_transmission_time): + self.retransmission_issue = False + self.abnormal_op_dict: Dict[str, List] = dict() + + def add_op(self, op_name: str, hccl_info: HcclInfo): + if self.abnormal_op_dict.get(op_name) is None: + self.abnormal_op_dict.setdefault(op_name, []) + self.abnormal_op_dict.get(op_name).append([hccl_info.group, op_name, hccl_info.step, hccl_info.rank, + hccl_info.get_rdma_transit_size(), + hccl_info.get_rdma_transmit_time(), hccl_info.get_rdma_bandwidth()]) + + +class CommunicationRetransmissionChecker: + def __init__(self, **kwargs): + self.rdma_issues = False + self.desc = "" + self.sdma_desc = "" + self.rdma_desc = "" + self.suggestions = [] + self.abnormal_group_count = 0 + self.abnormal_rdma_list = [] + self.step_id = kwargs.get("step") + self.stage = None + self.group_statistics = defaultdict(GroupStatistic) + self.headers = [ + "Communication group", + "Op name", + "Step id", + "Rank id", + "RDMA transmit size(MB)", + "RDMA transmit time(ms)", + "RDMA bandwidth", + ] + self._init_rule() + + def check_possible_retransmission_occurrence(self, hccl_list: List[HcclInfo]): + min_elapse_time = min(hccl.elapse_time for hccl in hccl_list) + max_transit_time = max(hccl.rdma_info.get('Transit Time(ms)', 0) for hccl in hccl_list) + if min_elapse_time < self.min_retransmission_time: # 检测是否是卡间不同步问题,而不是重传 + return False + return max_transit_time > self.min_retransmission_time + + def check_retransmission(self, hccl_dataset: ClusterCommunicationDataset): + """ + :Param event_dataset: dataset of timeline event + """ + for group_name, hccl_group_dict in hccl_dataset.hccl_dict.items(): + for op_name, hccl_op_dict in hccl_group_dict.items(): + for step_id, hccl_list in hccl_op_dict.items(): + if self.step_id and step_id != self.step_id: # 传输指定step(self.step_id)情况下,非目标step跳过 + continue + if not self.check_possible_retransmission_occurrence(hccl_list): + continue + self.rdma_issues = True + if self.group_statistics.get(group_name) is None: + self.group_statistics.setdefault(group_name, GroupStatistic(self.min_retransmission_time)) + self.abnormal_group_count += 1 + for hccl_info in hccl_list: + if hccl_info.rdma_info.get('Transit Size(MB)', 0): + transit_time = hccl_info.rdma_info.get('Transit Time(ms)', 0) + if transit_time > self.min_retransmission_time: + self.group_statistics.get(group_name).add_op(op_name, hccl_info) + if self.rdma_issues: + self.desc = self.desc.format(group_count=self.abnormal_group_count) + for _, group_statistic in self.group_statistics.items(): + for _, op_list in group_statistic.abnormal_op_dict.items(): + for op in op_list: + self.abnormal_rdma_list.append(op) + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + optimization_item = OptimizeItem("Communication retransmission analysis", self.desc, self.suggestions) + result.add(OptimizeRecord(optimization_item)) + + sub_table_name = \ + "Comm Retransmission Analysis" if not self.stage else f"Stage-{self.stage}: Comm Retransmission Analysis" + result.add_detail(sub_table_name, headers=self.headers) + + for row in self.abnormal_rdma_list: + result.add_detail(sub_table_name, detail=row) + + def make_render(self, html_render, add_render_list=True, **kwargs): + priority = kwargs.get("priority") + return html_render.render_template(key="communication", + template_dir="templates", + template_name="communication_retransmission_analysis.html", + desc=self.desc, + solutions=self.solutions, + headers=self.headers, + data=self.abnormal_rdma_list, + priority_background_color=priority) + + def _init_rule(self): + syncbn_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))), + "rules", + "rdma_analysis.yaml" + ) + + syncbn_rule = FileManager.read_yaml_file(syncbn_rule_path) + self.desc = syncbn_rule.get("problem") + self.min_retransmission_time = syncbn_rule.get("min_retransmission_time") + + self.solutions = syncbn_rule.get("solutions") + for solution in self.solutions: + for key, val in solution.items(): + self.suggestions.append(f"{key}, {val.get('desc')}") diff --git a/profiler/advisor/analyzer/comparison/__init__.py b/profiler/advisor/analyzer/comparison/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/analyzer/comparison/comparison_analyzer.py b/profiler/advisor/analyzer/comparison/comparison_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..b333c1174863d84f709f418081c1352be3e605d1 --- /dev/null +++ b/profiler/advisor/analyzer/comparison/comparison_analyzer.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.analyzer.comparison.comparison_checker import ComparisonChecker + +logger = logging.getLogger() + + +class ComparisonAnalyzer(BaseAnalyzer): + + def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: + super().__init__(collection_path, n_processes, **kwargs) + self.result = OptimizeResult() + self.html_render = HTMLRender() + + def optimize(self, compare_profiling_list, **kwargs): + for compare_profiling_path in compare_profiling_list: + self._optimize(**compare_profiling_path) + return self.result + + def get_priority(self): + pass + + def _optimize(self, profiling_path, benchmark_profiling_path, **kwargs): + comparison_checker = ComparisonChecker(profiling_path, + benchmark_profiling_path, + step=kwargs.get("step"), + benchmark_step=kwargs.get("benchmark_step"), + rank=kwargs.get("rank"), + benchmark_rank=kwargs.get("benchmark_rank")) + comparison_checker.compare(kwargs.get("compare_mode")) + comparison_checker.make_record(self.result) + comparison_checker.make_render(self.html_render) diff --git a/profiler/advisor/analyzer/comparison/comparison_checker.py b/profiler/advisor/analyzer/comparison/comparison_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4cb83d33c43614c90e198a6b35a2dc1f301782 --- /dev/null +++ b/profiler/advisor/analyzer/comparison/comparison_checker.py @@ -0,0 +1,155 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.advisor.utils.utils import safe_index_value, convert_to_float, convert_to_int +from profiler.compare_tools.compare_backend.utils.constant import Constant as CompareConstant +from profiler.compare_tools.compare_interface.comparison_interface import ComparisonInterface + +logger = logging.getLogger() + + +class ComparisonChecker: + BENCHMARK_PREFIX = "Benchmark " + SHOW_TOPK = 10 + DIFF_AVG_RATIO = "Diff Avg Ratio" + COMPARE_MODE_TO_DESC = { + CompareConstant.KERNEL_COMPARE: "Kernel compare", + CompareConstant.API_COMPARE: "Api compare", + } + + def __init__(self, profiling_path, benchmark_profiling_path, step=None, benchmark_step=None, rank=None, + benchmark_rank=None): + + self.profiling_path = profiling_path + self.benchmark_profiling_path = benchmark_profiling_path + self.step = ComparisonChecker.get_valid_step(step) + self.benchmark_step = ComparisonChecker.get_valid_step(benchmark_step) + self.rank = rank + self.benchmark_rank = benchmark_rank + self.compare_mode = None + self.format_result = {} + self.desc = None + self.suggestion = None + + @staticmethod + def get_valid_step(step): + none_step = None + if step is None: + return none_step + if isinstance(step, (int, float)): + if step < 0: + # 当没有step时,analyzer controller返回step=-1 + return none_step + else: + return str(convert_to_int(step)) + else: + return none_step + + def compare(self, compare_mode): + """ + :Param event_dataset: dataset of timeline event + """ + if compare_mode is None: + return + self.compare_mode = compare_mode + compare_interface = ComparisonInterface(self.profiling_path, self.benchmark_profiling_path, self.step, + self.benchmark_step) + result = compare_interface.compare(self.compare_mode) + data = result.get(self.compare_mode, {}) + headers = data.get("headers", {}) + rows = data.get("rows", []) + format_headers = [] + + for schema in headers: + name = schema.get("name", "null") + if name not in format_headers: + format_headers.append(name) + else: + format_headers.append(f"{self.BENCHMARK_PREFIX} {name}") + + if not rows: + return + + self.format_result[self.compare_mode] = {"headers": format_headers, "rows": rows} + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + if not self.format_result: + return + + sheet_name = self._get_sheet_name() + self.desc = sheet_name + optimization_item = OptimizeItem(sheet_name, self.desc, []) + result.add(OptimizeRecord(optimization_item)) + + result.add_detail(sheet_name, headers=self.format_result.get(self.compare_mode, {}).get("headers")) + + for row in self.format_result.get(self.compare_mode, {}).get("rows"): + result.add_detail(sheet_name, detail=row) + + def make_render(self, html_render, **kwargs): + if not self.format_result: + return + + headers = self.format_result.get(self.compare_mode, {}).get("headers", []) + diff_avg_index = safe_index_value(headers, self.DIFF_AVG_RATIO) + if diff_avg_index is None: + logger.warning("'%s' not exsits in headers of comparison result, skip render html.", self.DIFF_AVG_RATIO) + return + rows = self.format_result.get(self.compare_mode, {}).get("rows", []) + sorted_rows = sorted(rows, + key=lambda x: convert_to_float(x[diff_avg_index]) if diff_avg_index < len(x) else -1.0, + reverse=True) + + topk_rows = [] + if sorted_rows: + topk_rows = sorted_rows[:self.SHOW_TOPK] + + if not headers or not topk_rows: + return + + html_desc = self.desc + f". Only show {self.SHOW_TOPK} rows here, see mstt_advisor*.xlsx for details" + + html_render.render_template(key="comparison", + template_dir="templates", + template_name="comparison.html", + sheet_name=self._get_sheet_name(), + desc=html_desc, + headers=headers, + rows=topk_rows) + + def _get_sheet_name(self): + + sheet_name = "" + if self.rank is not None: + sheet_name += f"Rank{self.rank}" + if self.step is not None: + sheet_name += f" Step{self.step}" + if sheet_name: + sheet_name += " and " + if self.benchmark_rank is not None: + sheet_name += f"Rank{self.benchmark_rank}" + if self.benchmark_step is not None: + sheet_name += f" Step{self.benchmark_step}" + if not sheet_name: + sheet_name = "Target and Benchmark" + + sheet_name = f"{self.COMPARE_MODE_TO_DESC.get(self.compare_mode, '')} of {sheet_name}" + return sheet_name diff --git a/profiler/advisor/analyzer/computation/ai_core_freq/__init__.py b/profiler/advisor/analyzer/computation/ai_core_freq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/analyzer/computation/ai_core_freq/ai_core_freq_analyzer.py b/profiler/advisor/analyzer/computation/ai_core_freq/ai_core_freq_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd41df260eff3d03459e92d570923256816cb9f4 --- /dev/null +++ b/profiler/advisor/analyzer/computation/ai_core_freq/ai_core_freq_analyzer.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.analyzer.computation.ai_core_freq.ai_core_freq_checker import AICoreFreqChecker +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.timeline_event_dataset import ComputationAnalysisDataset +from profiler.advisor.dataset.profiling.device_info import DeviceInfoParser +from profiler.advisor.config.config import Config + +logger = logging.getLogger() + + +class AICoreFreqAnalyzer(BaseAnalyzer): + dataset_cls_list = [ComputationAnalysisDataset] + + def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: + super().__init__(collection_path, n_processes, **kwargs) + key = ComputationAnalysisDataset.get_key() + self.dataset = self.get_first_data_by_key(self.dataset_list, key) + self.result = OptimizeResult() + self.html_render = HTMLRender() + self.html = None + info = DeviceInfoParser(collection_path) + info.parse_data() + + @BaseAnalyzer.check_data((ComputationAnalysisDataset.get_key(),)) + def optimize(self, **kwargs): + if not Config().get_config("aic_frequency"): + logger.warning("Can not find ai core frequency in info.json*, please check data integrity.") + return self.result + + add_render_list = kwargs.get("add_render_list", True) + ai_core_freq_checker = AICoreFreqChecker() + ai_core_freq_checker.check_ai_core_freq(self.dataset, rank=kwargs.get("rank"), stage=kwargs.get("stage")) + ai_core_freq_checker.make_record(self.result) + self.html = ai_core_freq_checker.make_render(self.html_render, add_render_list, priority=self.get_priority(), + rank=kwargs.get("rank")) + return self.result + + def get_priority(self): + return PriorityBackgroundColor.high \ No newline at end of file diff --git a/profiler/advisor/analyzer/computation/ai_core_freq/ai_core_freq_checker.py b/profiler/advisor/analyzer/computation/ai_core_freq/ai_core_freq_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..f42b9514782e977c56a0f7776627beddbdebcd60 --- /dev/null +++ b/profiler/advisor/analyzer/computation/ai_core_freq/ai_core_freq_checker.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.dataset.timeline_event_dataset import ComputationAnalysisDataset +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.advisor.config.config import Config +from profiler.advisor.utils.utils import convert_to_float + +logger = logging.getLogger() + + +class AICoreFreqChecker: + DECREASE_FREQ_RATIO = 0.05 + SHOW_TOPK_OPS = 10 + TOTAL_DURATION_INDEX = 2 + DECREASE_FREQ_RATIO_INDEX = 3 + + def __init__(self): + + self.ai_core_freq_issues = False + self.desc = "" + self.suggestions = "" + self.decrease_freq_ops = [] + self.headers = [] + self.op_freq = None + self.rank = None + self.stage = None + + def check_ai_core_freq(self, event_dataset: ComputationAnalysisDataset, rank=None, stage=None): + """ + :Param event_dataset: dataset of timeline event + """ + if not hasattr(event_dataset, "op_freq") or not getattr(event_dataset, "op_freq"): + logger.debug("Skip slow ai core frequency checker, " + "because no ai core frequency were recorded in trace_view.json") + return + + self.rank = rank + self.stage = stage + self.op_freq = event_dataset.op_freq + for op_name, op_info in self.op_freq.items(): + freq_list = op_info.get("freq_list", []) + if not freq_list: + continue + + op_count = op_info.get("count", 0) + op_total_duration = round(op_info.get("dur", 0), 2) + max_freq = convert_to_float(Config().get_config("aic_frequency")) + + if max_freq == 0: + raise ValueError("max_freq cannot be zero.") + decrease_freq_ratio = sum(max_freq - freq for freq in freq_list) / (max_freq * len(freq_list)) + if decrease_freq_ratio >= Config().get_config("frequency_threshold"): + self.ai_core_freq_issues = True + self.decrease_freq_ops.append([op_name, op_count, op_total_duration, + f"{round(decrease_freq_ratio, 4):.2%}", + round(sum(freq_list) / len(freq_list), 2), + max(freq_list), min(freq_list)]) + + if self.decrease_freq_ops: + # 按算子总耗时和降频比率 降序排列 + self.decrease_freq_ops.sort(key = + lambda x: (x[self.TOTAL_DURATION_INDEX], x[self.DECREASE_FREQ_RATIO_INDEX]), + reverse = True) + if not self.ai_core_freq_issues: + return + + self.desc = (f"{len(self.decrease_freq_ops)} operators are found during frequency reduction, and the reduction " + f"ratio is larger than {self.DECREASE_FREQ_RATIO}.") + if self.rank: + self.desc = f"For rank {self.rank}, " + self.desc.lower() + self.suggestions = "Please check the temperature or max power of your machine." + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + if not self.ai_core_freq_issues: + return self.ai_core_freq_issues + + sheet_name = "AI Core Frequency" + if self.rank is not None: + sheet_name = f"rank {self.rank} AI Core Frequency".capitalize() + + optimization_item = OptimizeItem(sheet_name, self.desc, [self.suggestions]) + result.add(OptimizeRecord(optimization_item)) + + self.headers = [ + "Operator name", + "Count", + "Total duration(us)", + "AI CORE frequency decreased ratio", + "Average frequency", + "Max frequency", + "Min frequency", + ] + result.add_detail(sheet_name, headers=self.headers) + + for row in self.decrease_freq_ops: + result.add_detail(sheet_name, detail=row) + return True + + def make_render(self, html_render, add_render_list=True, **kwargs): + if not self.ai_core_freq_issues: + return self.ai_core_freq_issues + + priority = kwargs.get("priority") + if self.SHOW_TOPK_OPS: + self.desc += f" Only show {self.SHOW_TOPK_OPS} operators here, see latest mstt_advisor.xlsx for details." + return html_render.render_template(key="computation", + template_dir="templates", + template_name="ai_core_frequency.html", + desc=self.desc, + suggestion=self.suggestions, + headers=self.headers, + data=self.decrease_freq_ops[:self.SHOW_TOPK_OPS], + add_render_list=add_render_list, + priority_background_color=priority, + rank=kwargs.get("rank")) diff --git a/profiler/advisor/analyzer/computation/aicpu/aicpu_checker.py b/profiler/advisor/analyzer/computation/aicpu/aicpu_checker.py index 4eca1c6c0278349cf4068544d2a53d8de7f0d5e1..0c724f45aa2a65a40cb2fd53eebc84e930bd4646 100644 --- a/profiler/advisor/analyzer/computation/aicpu/aicpu_checker.py +++ b/profiler/advisor/analyzer/computation/aicpu/aicpu_checker.py @@ -1,15 +1,29 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import copy import os from functools import partial from typing import List, Dict, Optional -import yaml from profiler.advisor.analyzer.computation.operator_checker import OperatorChecker, logger from profiler.advisor.analyzer.schedule.fusion_ops.timeline_api_stack_checker import OpStackFinder from profiler.advisor.common import constant from profiler.advisor.dataset.dataset import Dataset from profiler.advisor.dataset.profiling.profiling_dataset import ProfilingDataset -from profiler.advisor.dataset.timeline_event_dataset import TimelineEventDataset +from profiler.advisor.dataset.timeline_event_dataset import ComputationAnalysisDataset +from profiler.cluster_analyse.common_func.file_manager import FileManager class AicpuChecker(OperatorChecker): @@ -30,25 +44,18 @@ class AicpuChecker(OperatorChecker): self.aicpu_rules: Dict = {} self.aicpu_checker: Dict = {} self.load_aicpu_rules() + self.total_task_duration = 0.0 + self.aicpu_task_duration = 0.0 - def _check_data(self, profiling_data: ProfilingDataset) -> bool: - if not self._check_summary(profiling_data): - return False - return True - - def _check_operator(self, op_info) -> bool: - return op_info.task_type == constant.AI_CPU - - def load_aicpu_rules(self, rule_path="rules/aicpu_rules.yaml") -> Dict: + def load_aicpu_rules(self, rule_path="rules/aicpu_rules.yaml"): if not os.path.isabs(rule_path): rule_path = os.path.join(os.path.dirname(__file__), "../../../", rule_path) if not os.path.exists(rule_path): logger.warning("Skip analyze aicpu issues, because %s does not exist.", rule_path) - return {} - with open(rule_path, 'r') as f: - self.aicpu_rules = yaml.safe_load(f) + + self.aicpu_rules = FileManager.read_yaml_file(rule_path) self.filter_aicpu_rules(self.aicpu_rules) for checker_name, check_rule in self.aicpu_rules.items(): if not isinstance(check_rule, (list, dict,)): @@ -88,22 +95,26 @@ class AicpuChecker(OperatorChecker): def get_opeartor_stack_info(api_stack_finder: OpStackFinder, op_name_list: list) -> list: data: Dict[str, Dataset] = {} - event_dataset = TimelineEventDataset(collection_path=profiling_data.collection_path, data=data, task_type=constant.AI_CPU) + event_dataset = ComputationAnalysisDataset(collection_path=profiling_data.collection_path, + data=data, + task_type=constant.AI_CPU) # disable multiprocessing, avoid cost time of enable new process for light task api_stack_finder.get_api_stack_by_op(event_dataset, op_name_list, constant.AI_CPU, disable_multiprocess=True) - return api_stack_finder._stack_record + return api_stack_finder.get_stack_record() self._op_list = [] - total_task_duration = 0.0 + max_task_duration = 0.0 for op_info in op_summary.op_list: + task_duration = float(op_info.task_duration) + if self._check_operator(op_info): self._op_list.append(op_info) + self.aicpu_task_duration += task_duration - task_duration = float(op_info.task_duration) - total_task_duration += task_duration + self.total_task_duration += task_duration max_task_duration = max(max_task_duration, task_duration) if (not self._op_list) or (max_task_duration < self._MIN_TASK_DURATION): return False @@ -135,21 +146,27 @@ class AicpuChecker(OperatorChecker): for op in self._op_list: if not op.has_attr("input_data_types"): logger.warning( - "Skip checking of input data in AICPU checker because of not containing input_data_dtypes in op summary") + "Skip checking of input data in AICPU checker " + "because of not containing input_data_dtypes in op summary") break - if op.has_attr( - "input_data_types") and "DOUBLE" in op.input_data_types and op.op_name not in double_type_ai_cpu_operator: + if (op.has_attr("input_data_types") and "DOUBLE" in op.input_data_types + and op.op_name not in double_type_ai_cpu_operator): double_type_ai_cpu_operator.append(op.op_name) if bool(double_type_ai_cpu_operator): self._SUGGESTION.append("Try to convert double type operator to float, such as {}".format( ",".join(double_type_ai_cpu_operator))) return True - def make_render(self, html_render, record): - html_render.render_template(key="computation", - template_dir="templates", - template_name="operator_ai_cpu.html", - format_result=self.format_operator_result(record, constant.OPERATOR_LIST_UNLIMIT)) + def make_render(self, html_render, record, add_render_list=True, **kwargs): + priority = kwargs.get("priority") + return html_render.render_template(key="computation", + template_dir="templates", + template_name="operator_ai_cpu.html", + format_result=self.format_operator_result(record, + constant.OPERATOR_LIST_UNLIMIT), + add_render_list=add_render_list, + priority_background_color=priority, + rank=kwargs.get("rank")) def format_operator_result(self, record, limit): """ @@ -163,23 +180,28 @@ class AicpuChecker(OperatorChecker): for suggestion in optimization_item.suggestion: release_suggestion_list.append(suggestion.replace('\n', '
')) logger.debug("suggestion list is %s", release_suggestion_list) - format_result = {"record": record.__dict__, "suggestion": '
'.join(release_suggestion_list), - "task_duration": round(record.statistics_item.task_duration, 2)} + format_result = { + "record": record.__dict__, + "suggestion": '
'.join(release_suggestion_list), + "task_duration": round(record.statistics_item.task_duration, 2), + } statistic = self.group_by(copy.deepcopy(self._op_list), op_key='op_type', limit=limit) format_result["statistic"] = statistic stack_key_list = ["stack_info", "input_data_types", "output_data_types"] if statistic: - for key, info in statistic: + for _, info in statistic: op_info_list = self.group_by_list(info.get("op_info_list"), stack_key_list, limit) info["op_info_list"] = op_info_list return format_result - def group_by_list(self, op_list, op_key_list: List = ["stack_info", "input_data_types", "output_data_types"], + def group_by_list(self, op_list, op_key_list: List = None, limit: int = constant.OPERATOR_LIST_UNLIMIT): if op_list is None: op_list = [] + if op_key_list is None: + op_key_list = ["stack_info", "input_data_types", "output_data_types"] # op_key_list 合并添加合并的属性,作为 groupby 的 key value op_key = '+'.join(op_key_list) # str, json @@ -192,6 +214,14 @@ class AicpuChecker(OperatorChecker): return self.group_by(op_list, op_key=op_key, limit=limit) + def _check_data(self, profiling_data: ProfilingDataset) -> bool: + if not self._check_summary(profiling_data): + return False + return True + + def _check_operator(self, op_info) -> bool: + return op_info.task_type == constant.AI_CPU + class BaserChecker: def __init__(self, *args, **kwargs): @@ -242,6 +272,7 @@ class CommonChecker(BaserChecker): return suggestion.format(",".join(unsupported_dtype_diff).upper(), op_type, ",".join(valid_inputs).upper()) + return None def build(self): for check in self.check_rules: @@ -266,6 +297,7 @@ class ExampleGuideChecker(BaserChecker): if getattr(op_info, 'op_type', "UNKNOWN").lower() in supported_op_type: return suggestion if "{}" not in suggestion else suggestion.format(url) + return None for check in self.check_rules: (_, check_rule), = check.items() diff --git a/profiler/advisor/analyzer/computation/bound/block_dim_checker.py b/profiler/advisor/analyzer/computation/bound/block_dim_checker.py index a7d7ddd93c70e59dc0d10318fdac06fdc581f70c..6eef6f81310c9a186c57b340f163f691f7336d76 100644 --- a/profiler/advisor/analyzer/computation/bound/block_dim_checker.py +++ b/profiler/advisor/analyzer/computation/bound/block_dim_checker.py @@ -1,5 +1,18 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging - from typing import List from profiler.advisor.analyzer.computation.operator_checker import OperatorChecker @@ -14,6 +27,8 @@ class BlockDimChecker(OperatorChecker): _SUGGESTION: List[str] = [] _CHECKER = "block dim" _PROBLEM = "block dim" + _aicore_num = 0 + _aiv_num = 0 _description = "some operator does not make full use of {} ai core" _ITEMS = [ "op_name", "op_type", "task_type", "task_duration", "income", "block_dim", "mix_block_dim", "input_shapes", @@ -23,22 +38,50 @@ class BlockDimChecker(OperatorChecker): def pre_check(self, profiling_data) -> bool: return not self.is_dynamic_shape(profiling_data) - def _check_data(self, data): - self.format_suggestion_content(data) - if not self._check_summary(data): + def make_render(self, html_render, record, add_render_list=True, **kwargs): + priority = kwargs.get("priority") + return html_render.render_template(key="computation", + template_dir="templates", + template_name="operator_block_dim.html", + format_result=self.format_operator_result(record, + constant.OPERATOR_OUT_TOPK), + add_render_list=add_render_list, + priority_background_color=priority, + rank=kwargs.get("rank")) + + def get_core_num(self, op_info): + """ + get core num of task type + """ + if op_info.task_type == "AI_CORE" or not self._aiv_num: + core_num = self._aicore_num + else: + core_num = self._aiv_num + return core_num + + def _check_data(self, profiling_data): + self.format_suggestion_content(profiling_data) + if not self._check_summary(profiling_data): return False if not Config().get_config("ai_core_num"): logger.warning(self.SKIP_CHECK_MSG, self._CHECKER, "ai core num in info.json file") return False - summary = data.op_summary + summary = profiling_data.op_summary op_info = summary.op_list[0] if not hasattr(op_info, "block_dim"): logger.warning(self.SKIP_CHECK_MSG, self._CHECKER, "block dim in op summary") return False if Config().get_config("ai_core_num"): - self._aicore_num = int(Config().get_config("ai_core_num")) + try: + self._aicore_num = int(Config().get_config("ai_core_num")) + except ValueError as e: + logger.warning("get ai_core_num failed, please check info.json: %s", e) + return False if Config().get_config("aiv_num"): - self._aiv_num = int(Config().get_config("aiv_num")) + try: + self._aiv_num = int(Config().get_config("aiv_num")) + except ValueError as e: + logger.warning("get aiv_num failed, please check info.json: %s", e) self._description = self._description.format(self._aicore_num) if self._aiv_num: self._description += f" or {self._aiv_num} ai vector core" @@ -46,30 +89,17 @@ class BlockDimChecker(OperatorChecker): "task duration are as follows:\n" return True - def make_render(self, html_render, record): - html_render.render_template(key="computation", - template_dir="templates", - template_name="operator_block_dim.html", - format_result=self.format_operator_result(record, constant.OPERATOR_OUT_TOPK)) - def _check_operator(self, op_info) -> bool: if op_info.task_type not in ["AI_CORE", "AI_VECTOR_CORE", "MIX_AIC"]: return False block_dim = int(op_info.block_dim) core_num = self.get_core_num(op_info) + if core_num == 0: + logger.error("The aicore number is zero. BlockDimChecker is skipped. Please check the info.json file.") + return False if block_dim % core_num == 0: return False - if op_info.task_type == "MIX_AIC" and hasattr(op_info, "mix_block_dim") \ - and self._aiv_num and int(op_info.mix_block_dim) % self._aiv_num == 0: + is_block_dim = op_info.task_type == "MIX_AIC" and hasattr(op_info, "mix_block_dim") + if is_block_dim and self._aiv_num and int(op_info.mix_block_dim) % self._aiv_num == 0: return False return True - - def get_core_num(self, op_info): - """ - get core num of task type - """ - if op_info.task_type == "AI_CORE" or not self._aiv_num: - core_num = self._aicore_num - else: - core_num = self._aiv_num - return core_num diff --git a/profiler/advisor/analyzer/computation/bound/operator_bound_checker.py b/profiler/advisor/analyzer/computation/bound/operator_bound_checker.py index a22b380f974b14207d6d7be262cd49f0ba0fbe99..9ef64e546948945a76a9a3ea7a0a142bd94b2b4d 100644 --- a/profiler/advisor/analyzer/computation/bound/operator_bound_checker.py +++ b/profiler/advisor/analyzer/computation/bound/operator_bound_checker.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from typing import List @@ -27,11 +41,22 @@ class OperatorBoundChecker(OperatorChecker): def pre_check(self, profiling_data) -> bool: return not self.is_dynamic_shape(profiling_data) - def _check_data(self, data): - self.format_suggestion_content(data) - if not self._check_summary(data): + def make_render(self, html_render, record, add_render_list=True, **kwargs): + priority = kwargs.get("priority") + return html_render.render_template(key="computation", + template_dir="templates", + template_name="operator_no_bound.html", + format_result=self.format_operator_result(record, + constant.OPERATOR_OUT_TOPK), + add_render_list=add_render_list, + priority_background_color=priority, + rank=kwargs.get("rank")) + + def _check_data(self, profiling_data): + self.format_suggestion_content(profiling_data) + if not self._check_summary(profiling_data): return False - for op_info in data.op_summary.op_list: + for op_info in profiling_data.op_summary.op_list: return self._check_operator(op_info) logger.warning(self.SKIP_CHECK_MSG, self._CHECKER, "ratio in op summary") @@ -45,9 +70,3 @@ class OperatorBoundChecker(OperatorChecker): if any(ratio and ratio > Config().operator_bound_ratio for ratio in ratio_list): return False return True - - def make_render(self, html_render, record): - html_render.render_template(key="computation", - template_dir="templates", - template_name="operator_no_bound.html", - format_result=self.format_operator_result(record, constant.OPERATOR_OUT_TOPK)) diff --git a/profiler/advisor/analyzer/computation/op_compile/dynamic_shape_checker.py b/profiler/advisor/analyzer/computation/op_compile/dynamic_shape_checker.py index 86d3bac4ff8cb163d23a6365307b855839b12a6a..4b65d5a48f73673162a8b9a2bda7386e7318e266 100644 --- a/profiler/advisor/analyzer/computation/op_compile/dynamic_shape_checker.py +++ b/profiler/advisor/analyzer/computation/op_compile/dynamic_shape_checker.py @@ -1,9 +1,24 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import copy import logging from typing import List from profiler.advisor.analyzer.computation.operator_checker import OperatorChecker from profiler.advisor.common import constant +from profiler.advisor.config.config import Config from profiler.advisor.dataset.profiling.info_collection import OpInfo from profiler.advisor.result.item import OptimizeItem, StatisticsItem, OptimizeRecord @@ -11,8 +26,9 @@ logger = logging.getLogger() class DynamicShapeChecker(OperatorChecker): - ENABLE_COMPILED_SUGGESTION = "Optimize by enabling compiled operator, such as:\n" \ - "`torch_npu.npu.set_compile_mode(jit_compile=False)`\n" + ENABLE_COMPILED_SUGGESTION = "Please place the following code at the entrance of the python script to disable jit compile.\n " \ + "Code: `torch_npu.npu.set_compile_mode(jit_compile=False);\n " \ + "torch_npu.npu.config.allow_internal_format = False`" _SUGGESTION: List[str] = [ENABLE_COMPILED_SUGGESTION] _CHECKER = "dynamic shape operator" _PROBLEM = "Dynamic shape operator" @@ -24,14 +40,16 @@ class DynamicShapeChecker(OperatorChecker): def __init__(self, cann_version) -> None: super().__init__(cann_version=cann_version) - def check(self, profiling_database) -> bool: - return self.is_dynamic_shape(profiling_database) + def check(self, profiling_data) -> bool: + return self.is_dynamic_shape(profiling_data) - def make_record(self, profiling_database) -> OptimizeRecord: + def make_record(self, profiling_data, rank=None) -> OptimizeRecord: """ make record for what and how to optimize """ + if rank is not None: + self._PROBLEM = f"rank {rank} ".capitalize() + self._PROBLEM.lower() optimization_item = OptimizeItem( self._PROBLEM, self._description, @@ -53,13 +71,17 @@ class DynamicShapeChecker(OperatorChecker): release_suggestion = copy.deepcopy(suggestion) if release_suggestion == DynamicShapeChecker.ENABLE_COMPILED_SUGGESTION: release_suggestion += \ - f"for details please refer to link : LINK" + f"for details please refer to link : LINK" release_suggestion_list.append(release_suggestion.replace('\n', '
')) format_result = {"record": record.__dict__, "suggestion": '
'.join(release_suggestion_list)} return format_result - def make_render(self, html_render, record): - html_render.render_template(key="computation", - template_dir="templates", - template_name="operator_dynamic_shape.html", - format_result=self.format_operator_result(record)) + def make_render(self, html_render, record, add_render_list=True, **kwargs): + priority = kwargs.get("priority") + return html_render.render_template(key="computation", + template_dir="templates", + template_name="operator_dynamic_shape.html", + format_result=self.format_operator_result(record), + add_render_list=add_render_list, + priority_background_color=priority, + rank=kwargs.get("rank")) diff --git a/profiler/advisor/analyzer/computation/operator_checker.py b/profiler/advisor/analyzer/computation/operator_checker.py index 64618b56a8df7f380277e99ae7ca47cd69d24648..17be15b4eb547e1d2fce198fd0eeaef071b8ad31 100644 --- a/profiler/advisor/analyzer/computation/operator_checker.py +++ b/profiler/advisor/analyzer/computation/operator_checker.py @@ -1,21 +1,36 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import copy import logging from textwrap import fill from typing import List from profiler.advisor.common import constant +from profiler.advisor.common.enum_params_parser import EnumParamsParser from profiler.advisor.common.version_control import VersionControl from profiler.advisor.config.config import Config from profiler.advisor.dataset.profiling.info_collection import OpInfo from profiler.advisor.dataset.profiling.profiling_dataset import ProfilingDataset from profiler.advisor.result.item import OptimizeItem, StatisticsItem, OptimizeRecord -from profiler.advisor.utils.utils import safe_division +from profiler.advisor.utils.utils import safe_division, convert_to_float logger = logging.getLogger() class OperatorChecker(VersionControl): - _SUPPORT_VERSIONS = constant.SUPPORTED_CANN_VERSION + _SUPPORT_VERSIONS = EnumParamsParser().get_options(constant.CANN_VERSION) _MAX_TUNE_OP_NUM = constant.OPERATOR_OUT_TOPK _MIN_TASK_DURATION = 0 _MIN_TASK_DURATION_RATIO = 1.0 @@ -33,12 +48,30 @@ class OperatorChecker(VersionControl): f"--tune_ops_file={Config().tune_ops_file}'\n" MSLite_OPERATOR_TUNE_SUGGESTION = f"Optimize operator by AOE in mindspore lite framework, such as:\n" \ f"converter_lite --fmk=ONNX --optimize=ascend_oriented --saveType=MINDIR " \ - f"--modelFile=$user_model.onnx --outputFile=user_model --configFile=./config.txt\n" - _tune_op_list: List[str] = [] + f"--modelFile=$user_model.onnx --outputFile=user_model " \ + f"--configFile=./config.txt\n" def __init__(self, cann_version: str): self.cann_version = cann_version self._op_list: List[OpInfo] = [] + self._tune_op_list: List[str] = [] + + @staticmethod + def get_ratio(op_info: OpInfo, attr: str) -> float: + if not op_info.has_attr(attr): + return 0 + value = op_info.get_attr(attr) + if not value or value == "N/A": + return 0 + return float(value) + + @classmethod + def get_name(cls): + """ + get name of checker + :return: checker name + """ + return cls._PROBLEM def check(self, profiling_data: ProfilingDataset) -> bool: """ @@ -77,14 +110,19 @@ class OperatorChecker(VersionControl): return True return False - def make_record(self, profiling_data: ProfilingDataset): + def make_record(self, profiling_data: ProfilingDataset, rank=None): """ Make record for what and how to optimize :param profiling_data: profiling data :return: optimize record """ - task_duration_list = [float(op_info.get_attr("task_duration")) for op_info in self._op_list if - hasattr(op_info, "get_attr")] + + if rank is not None: + self._PROBLEM = f"rank {rank} ".capitalize() + self._PROBLEM.lower() + + task_duration_list = [float(op_info.get_attr("task_duration")) + for op_info in self._op_list + if hasattr(op_info, "get_attr")] total_cost_time = sum(task_duration_list) total_task_duration = profiling_data.op_summary.get_total_task_duration() count = len(task_duration_list) @@ -101,7 +139,7 @@ class OperatorChecker(VersionControl): return description desc_suffix = [] - for i in range(len(op_type_list)): + for i, _ in enumerate(op_type_list): if i % 3 == 0 and i != 0: desc_suffix.append("\n") @@ -117,7 +155,11 @@ class OperatorChecker(VersionControl): return True def is_dynamic_shape(self, profiling_database: ProfilingDataset) -> bool: - less_than_cann800_list = [constant.CANN_VERSION_C30, constant.CANN_VERSION_C13, constant.CANN_VERSION_C15] + cann800_major_version = 8 + less_than_cann800_list = EnumParamsParser().get_options( + constant.CANN_VERSION, + filter_func=lambda x: convert_to_float(x.split(".")[0]) < cann800_major_version + ) # CANN 8.0.RC1 之前从 ge_info 中获取 op_state 属性,进行动态 shape 逻辑判断 if self.cann_version in less_than_cann800_list: if hasattr(profiling_database, "ge_info"): @@ -128,8 +170,9 @@ class OperatorChecker(VersionControl): else: logger.warning( "Skip dynamic shape check because of not containing ge_info.db file in host filefloder.\n" - "To enable dynamic shape check, please try to set data_simplification=False in experimental_config.\n" - "More details please refer to link : %s", constant.ASCEND_PROFILER_URL) + "To enable dynamic shape check, " + "please try to set data_simplification=False in experimental_config.\n" + "More details please refer to link : %s", Config().ascend_profiler_url) else: # CANN 8.0.RC1 之后 op_state 属性从 op_summary 文件中获取 if hasattr(profiling_database, "op_summary"): @@ -138,8 +181,8 @@ class OperatorChecker(VersionControl): return True else: logger.warning( - "Skip dynamic shape check because of not containing op_summary.csv file in current filefloder." - ) + "Skip dynamic shape check because of not containing op_summary.csv file in current filefloder." + ) return False def format_operator_result(self, record, limit): @@ -155,18 +198,21 @@ class OperatorChecker(VersionControl): release_suggestion = copy.deepcopy(suggestion) if release_suggestion == OperatorChecker.PyTorch_OPERATOR_TUNE_SUGGESTION: release_suggestion += \ - (f"for details please refer to link : LINK") + (f"for details please refer to link : LINK") elif release_suggestion == OperatorChecker.MSLite_OPERATOR_TUNE_SUGGESTION: release_suggestion += \ (f"\nThe config file for MSLite AOE usage is as follows:\n" \ f"[ascend_context]\n" \ f"aoe_mode=\"operator tuning\"\n" \ f"--tune_ops_file={Config().tune_ops_file}\n" - f"\nFor details please refer to link : LINK") + f"\nFor details please refer to link : LINK") release_suggestion_list.append(release_suggestion.replace('\n', '
')) - format_result = {"record": record.__dict__, - "suggestion": fill('
'.join(release_suggestion_list), width=200), - "task_duration": round(record.statistics_item.task_duration, 2)} + format_result = { + "record": record.__dict__, + "suggestion": fill('
'.join(release_suggestion_list), width=200), + "task_duration": round(record.statistics_item.task_duration, 2), + } statistic = self.group_by(copy.deepcopy(self._op_list), limit=limit) format_result["statistic"] = statistic return format_result @@ -184,15 +230,21 @@ class OperatorChecker(VersionControl): op_list = [] statistic = {} # str, json for op_info in op_list: - if statistic.get(op_info.get_attr(op_key)): - statistic[op_info.get_attr(op_key)]["summary"]["total_duration"] = float( - statistic[op_info.get_attr(op_key)]["summary"]["total_duration"]) + float( - op_info.get_attr("task_duration", constant.DEFAULT_DURATION_ZERO)) - statistic[op_info.get_attr(op_key)]["summary"]["counts"] += 1 + statistic_op_key = statistic.get(op_info.get_attr(op_key), {}) + summary = statistic_op_key.get("summary", {}) + if summary: + if summary.get("total_duration"): + summary["total_duration"] = float( + summary["total_duration"]) + float( + op_info.get_attr("task_duration", constant.DEFAULT_DURATION_ZERO)) + if summary.get("counts"): + summary["counts"] += 1 stack_info = op_info.get_attr("stack_info") if stack_info: op_info.stack_info = stack_info.replace('\r\n', '
') - statistic[op_info.get_attr(op_key)]["op_info_list"].append(op_info) + if statistic_op_key.get("op_info_list") is None: + statistic_op_key["op_info_list"] = [] + statistic_op_key["op_info_list"].append(op_info) else: statistic[op_info.get_attr(op_key)] = {"summary": {}, "op_info_list": []} statistic[op_info.get_attr(op_key)]["summary"]["op_type"] = op_info.get_attr( @@ -219,15 +271,6 @@ class OperatorChecker(VersionControl): logger.warning("%s checker do not has results to format html", str(self.__class__.__name__)) return statistic - def _check_data(self, profiling_data): - return True - - def _check_operator(self, op_info): - return False - - def _get_income(self, _op_info: OpInfo) -> float: - return 0 - def get_tune_op_list(self): """ get tune op list @@ -239,14 +282,6 @@ class OperatorChecker(VersionControl): """Get node views.""" return [] - @classmethod - def get_name(cls): - """ - get name of checker - :return: checker name - """ - return cls._PROBLEM - def get_incomes(self) -> float: """get incomes""" incomes = 0.0 @@ -264,21 +299,6 @@ class OperatorChecker(VersionControl): op_type_list.append(op_info.op_type) return op_type_list - def _check_summary(self, data: ProfilingDataset): - if not hasattr(data, "op_summary"): - logger.warning(self.SKIP_CHECK_MSG, self._CHECKER, "op summary") - return False - return True - - @staticmethod - def get_ratio(op_info: OpInfo, attr: str) -> float: - if not op_info.has_attr(attr): - return 0 - value = op_info.get_attr(attr) - if not value or value == "N/A": - return 0 - return float(value) - def get_details(self) -> list: """ get details of operator to be optimized @@ -301,7 +321,22 @@ class OperatorChecker(VersionControl): return details def format_suggestion_content(self, profiling_data: ProfilingDataset) -> None: - if profiling_data.PROF_TYPE == constant.ASCEND_PYTORCH_PROFILER: + if profiling_data.PROF_TYPE == EnumParamsParser().profiling_type.ascend_pytorch_profiler: self._SUGGESTION.append(self.PyTorch_OPERATOR_TUNE_SUGGESTION) - elif profiling_data.PROF_TYPE == constant.MSLITE: + elif profiling_data.PROF_TYPE == EnumParamsParser.profiling_type.mslite: self._SUGGESTION.append(self.MSLite_OPERATOR_TUNE_SUGGESTION) + + def _check_data(self, profiling_data): + return True + + def _check_operator(self, op_info): + return False + + def _get_income(self, _op_info: OpInfo) -> float: + return 0 + + def _check_summary(self, data: ProfilingDataset): + if not hasattr(data, "op_summary"): + logger.warning(self.SKIP_CHECK_MSG, self._CHECKER, "op summary") + return False + return True \ No newline at end of file diff --git a/profiler/advisor/analyzer/computation/pp_stage_computation_analyzer.py b/profiler/advisor/analyzer/computation/pp_stage_computation_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..97bfff8abbb552c23ffab84d0c7504774ef231e4 --- /dev/null +++ b/profiler/advisor/analyzer/computation/pp_stage_computation_analyzer.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from multiprocessing import Manager + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.common.analyzer_scopes import SupportedScopes +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor +from profiler.advisor.interface.interface import Interface +from profiler.advisor.utils.utils import ParallelJob, get_analyze_processes +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.advisor.common import constant as const + +logger = logging.getLogger() + + +class PPStageComputationAnalyzer(BaseAnalyzer): + + def __init__(self, collection_path, **kwargs): + super().__init__(collection_path, **kwargs) + self.collection_path = collection_path + self._stages_rendered_html = Manager().list() + self._multiprocess_result = Manager().dict() + # html render不能序列化,无法用多进程,放到optimize里面初始化 + self.html_render = None + self.result = None + + @staticmethod + def _get_valid_sheet_name(sheet_name, prefix): + if not sheet_name.lower().startswith(prefix.lower()): + sheet_name = f"{prefix} {sheet_name}" + return sheet_name + + def optimize(self, stages_profiling_path, **kwargs): + pp_stage_processes = min(get_analyze_processes(), len(stages_profiling_path)) + if pp_stage_processes <= 1: + for stage_profiling_path in stages_profiling_path: + self._optimize(**stage_profiling_path) + else: + logger.info("Start to parallel analysis of pp stages, number of processes is %s", pp_stage_processes) + parallel_stage_analysis_job = ParallelJob(self._optimize, stages_profiling_path, + "Computation analysis of Pipeline parallel stages") + parallel_stage_analysis_job.start(pp_stage_processes) + self._merge_multiprocess_result() + + self.make_render() + self.html_render = HTMLRender() + return self.result + + def make_render(self): + HTMLRender().render_template(key="computation", + template_dir="templates", + template_name="pp_stage_computation_analysis.html", + stages_rendered_html=list(self._stages_rendered_html), + priority_background_color=PriorityBackgroundColor.high) + + def get_priority(self): + pass + + def _optimize(self, profiling_path, **kwargs): + stage_html_record = dict(stage=kwargs.get("stage"), rank=kwargs.get("rank"), step=kwargs.get("step")) + kwargs["add_render_list"] = False + + # stage 并行分析时,避免调用本身,即SupportedScopes.STAGE_COMPUTE + scopes = Interface.get_scope(Interface.COMPUTATION) + stage_analyzer_list = [Interface.get_analyzer(Interface.COMPUTATION, scope) + for scope in scopes + if scope != SupportedScopes.STAGE_COMPUTE] + + for analyzer_cls in stage_analyzer_list: + analyzer = analyzer_cls(collection_path=profiling_path, **kwargs) + result = analyzer.optimize(**kwargs) + if hasattr(result, "data") and result.data: + self.result = result + if hasattr(analyzer, "html") and analyzer.html: + if "html_list" not in stage_html_record: + stage_html_record["html_list"] = [] + stage_html_record["html_list"].append(analyzer.html) + self._stages_rendered_html.append(stage_html_record) + self._multiprocess_result[f"rank {kwargs.get('rank')}".capitalize()] = result.data + + def _merge_multiprocess_result(self): + self.result = OptimizeResult() + for key, result_data in self._multiprocess_result.items(): + problem_data = result_data.get("problems", {}).get("data", []) + if not problem_data: + continue + + for row in problem_data: + if len(row) < 3: + continue + issue_name, desc, suggestion = row[:3] + sheet_name = PPStageComputationAnalyzer._get_valid_sheet_name(issue_name, key) + optimization_item = OptimizeItem(sheet_name, desc, [suggestion]) + self.result.add(OptimizeRecord(optimization_item)) + del result_data["problems"] + + for issue_name, issue_details in result_data.items(): + headers = issue_details.get("headers", []) + data = issue_details.get("data", []) + sheet_name = PPStageComputationAnalyzer._get_valid_sheet_name(issue_name, key) + self.result.add_detail(sheet_name, headers=headers) + + for row in data: + self.result.add_detail(sheet_name, detail=row) diff --git a/profiler/advisor/analyzer/computation/profiling_analyzer.py b/profiler/advisor/analyzer/computation/profiling_analyzer.py index 8682617700702055628a31982b0eafab9feb336d..6d525f303cc8c5971bda8a11d16d638ef3dcf2c3 100644 --- a/profiler/advisor/analyzer/computation/profiling_analyzer.py +++ b/profiler/advisor/analyzer/computation/profiling_analyzer.py @@ -1,19 +1,30 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from abc import ABC -from typing import Dict, List from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer -from profiler.advisor.common import constant from profiler.advisor.result.result import OptimizeResult from profiler.advisor.analyzer.computation.aicpu.aicpu_checker import AicpuChecker from profiler.advisor.analyzer.computation.bound.block_dim_checker import BlockDimChecker from profiler.advisor.analyzer.computation.bound.operator_bound_checker import OperatorBoundChecker -from profiler.advisor.analyzer.computation.operator_checker import OperatorChecker from profiler.advisor.analyzer.computation.op_compile.dynamic_shape_checker import DynamicShapeChecker from profiler.advisor.analyzer.computation.operator_checker import OperatorChecker +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor from profiler.advisor.display.html.render import HTMLRender from profiler.advisor.dataset.profiling.profiling_dataset import ProfilingDataset -from profiler.advisor.utils.utils import get_supported_subclass logger = logging.getLogger() @@ -26,6 +37,7 @@ class ProfilingAnalyzer(BaseAnalyzer, ABC): self.checker = OperatorChecker(self.cann_version) self.html_render = HTMLRender() self.result = OptimizeResult() + self.html = None @BaseAnalyzer.check_data((ProfilingDataset.get_key(),)) def optimize(self, **kwargs) -> OptimizeResult: @@ -36,22 +48,29 @@ class ProfilingAnalyzer(BaseAnalyzer, ABC): """ profiling_data = self.get_first_data_by_key(self.dataset_list, ProfilingDataset.get_key()) checker = self.checker + rank = kwargs.get("rank") + + add_render_list = kwargs.get("add_render_list", True) + if not checker.pre_check(profiling_data): return self.result if checker.check(profiling_data): # add record - record = checker.make_record(profiling_data) - checker.make_render(self.html_render, record) + record = checker.make_record(profiling_data, rank) + self.html = checker.make_render(self.html_render, record, add_render_list, + priority=self.get_priority(checker), rank=kwargs.get("rank")) self.result.add(record) # add details details = checker.get_details() if details: for i, detail in enumerate(details): + sheet_name = checker.get_name() if rank is None else \ + f"rank {rank} ".capitalize() + checker.get_name() if i == 0: # the first row is header - self.result.add_detail(checker.get_name(), headers=detail) + self.result.add_detail(sheet_name, headers=detail) else: - self.result.add_detail(checker.get_name(), detail=detail) + self.result.add_detail(sheet_name, detail=detail) # add tune op list tune_op_list = checker.get_tune_op_list() if tune_op_list: @@ -59,11 +78,13 @@ class ProfilingAnalyzer(BaseAnalyzer, ABC): return self.result - def make_record(self): - pass + def get_priority(self,max_mem_op_dur): + if "aicpu" not in max_mem_op_dur.__class__.__name__.lower(): + return PriorityBackgroundColor.low - def make_render(self): - pass + aicpu_duration = getattr(max_mem_op_dur, "aicpu_task_duration", 0.0) + total_duration = getattr(max_mem_op_dur, "total_task_duration", 0.0) + return self.get_priority_by_time_ratio(aicpu_duration, total_duration) class DynamicShapeAnalyzer(ProfilingAnalyzer): @@ -76,14 +97,15 @@ class BlockDimAnalyzer(ProfilingAnalyzer): def __init__(self, collection_path, **kwargs) -> None: super().__init__(collection_path, **kwargs) self.checker = BlockDimChecker(self.cann_version) - + class OperatorBoundAnalyzer(ProfilingAnalyzer): def __init__(self, collection_path, **kwargs) -> None: super().__init__(collection_path, **kwargs) self.checker = OperatorBoundChecker(self.cann_version) + class AicpuAnalyzer(ProfilingAnalyzer): def __init__(self, collection_path, **kwargs) -> None: super().__init__(collection_path, **kwargs) - self.checker = AicpuChecker(self.cann_version) \ No newline at end of file + self.checker = AicpuChecker(self.cann_version) diff --git a/profiler/advisor/analyzer/dataloader/dataloader_analyzer.py b/profiler/advisor/analyzer/dataloader/dataloader_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..5c97773ef26247214814f490600e4a986890ec26 --- /dev/null +++ b/profiler/advisor/analyzer/dataloader/dataloader_analyzer.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from typing import List, Dict, Any + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.analyzer.dataloader.dataloader_checker import DataloaderChecker +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset + +logger = logging.getLogger() + + +class DataloaderAnalyzer(BaseAnalyzer): + dataset_cls_list = [ScheduleAnalysisDataset] + + def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: + super().__init__(collection_path, n_processes, **kwargs) + key = ScheduleAnalysisDataset.get_key() + self.dataset = self.get_first_data_by_key(self.dataset_list, key) + self.result = OptimizeResult() + self.html_render = HTMLRender() + + @BaseAnalyzer.check_data((ScheduleAnalysisDataset.get_key(),)) + def optimize(self, **kwargs): + dataloader_checker = DataloaderChecker() + dataloader_checker.check_slow_dataloader(self.dataset) + dataloader_checker.make_record(self.result) + dataloader_checker.make_render(self.html_render, priority=self.get_priority(), rank=kwargs.get("rank")) + return self.result + + def get_priority(self): + return PriorityBackgroundColor.high diff --git a/profiler/advisor/analyzer/dataloader/dataloader_checker.py b/profiler/advisor/analyzer/dataloader/dataloader_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ba7713c7070e60a44dd47934350440eeb1f2f2 --- /dev/null +++ b/profiler/advisor/analyzer/dataloader/dataloader_checker.py @@ -0,0 +1,101 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +import logging +import yaml + +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.cluster_analyse.common_func.file_manager import FileManager + +logger = logging.getLogger() + + +class DataloaderChecker: + + def __init__(self): + + self.dataloader_issues = False + self.optimization_item = [] + self.desc = "" + self.suggestions = [] + self.dataloader_duration_threshold = None + self._init_rule() + + def check_slow_dataloader(self, event_dataset: ScheduleAnalysisDataset): + """ + :Param event_dataset: dataset of timeline event + """ + if not hasattr(event_dataset, "dataloader") or not getattr(event_dataset, "dataloader"): + logger.debug("Skip slow dataloader checker, because no dataloader duration larger than %s", + self.dataloader_duration_threshold) + return + for event in event_dataset.dataloader: + + dataloader_duration = float(event.dur) + if dataloader_duration < self.dataloader_duration_threshold: + continue + self.desc = self.desc.format(dataloader_duration=dataloader_duration, + dataloader_duration_threshold=self.dataloader_duration_threshold) + self.dataloader_issues = True + + if re.search("singleprocess", event.name.lower()): + self.suggestions = self._reset_suggestions(["I/O", "num_workers"]) + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + if not self.dataloader_issues: + return + + self.optimization_item.append(OptimizeItem("Slow dataloader", self.desc, self.suggestions)) + for optimization in self.optimization_item: + result.add(OptimizeRecord(optimization)) + + def make_render(self, html_render, **kwargs): + if not self.dataloader_issues: + return + priority = kwargs.get("priority") + html_render.render_template(key="dataloader", + template_dir="templates", + template_name="slow_dataloader.html", + desc=self.desc, + suggestions=self.suggestions, + priority_background_color=priority, + rank=kwargs.get("rank")) + + def _init_rule(self): + dataloader_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), + "rules", + "dataloader.yaml" + ) + dataloader_rule = FileManager.read_yaml_file(dataloader_rule_path) + + self.dataloader_duration_threshold = dataloader_rule.get("dataloader_duration_threshold") + self.desc = dataloader_rule.get("problem") + self.suggestions = dataloader_rule.get("solutions") + + def _reset_suggestions(self, suggestion_pattern_list): + + suggestions = [] + for solution in self.suggestions: + for suggestion_pattern in suggestion_pattern_list: + if re.search(suggestion_pattern, solution): + suggestions.append(solution) + return suggestions diff --git a/profiler/advisor/analyzer/graph_fusion/graph_fusion_analyzer.py b/profiler/advisor/analyzer/graph_fusion/graph_fusion_analyzer.py index 326be83b8d49088b1563ccd8c08b68a4aa3001ef..b72e1316a452303020e08f16a80a28c11717115f 100644 --- a/profiler/advisor/analyzer/graph_fusion/graph_fusion_analyzer.py +++ b/profiler/advisor/analyzer/graph_fusion/graph_fusion_analyzer.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List from functools import partial @@ -20,17 +34,22 @@ class FusionOPAnalyzer(BaseAnalyzer): super(FusionOPAnalyzer, self).__init__(collection_path, **kwargs) self.result = OptimizeResult() self.html_render = HTMLRender() - + self.html = None + @BaseAnalyzer.check_data((GraphDataset.get_key(),)) def optimize(self, **kwargs): """ :return: result """ - self._check(self.dataset_list.get("GraphDataset"), self.dataset_list.get("ProfilingDataset")) + self._check(self.dataset_list.get("GraphDataset"), self.dataset_list.get("ProfilingDataset"), + kwargs.get("add_render_list")) return self.result - def _check(self, graph_data: List[GraphDataset], - profiling_data: List[ProfilingDataset] = None) -> None: + def get_priority(self): + pass + + def _check(self, graph_data: List[GraphDataset], profiling_data: List[ProfilingDataset] = None, + add_render_list=True) -> None: if len(graph_data) == 0 or graph_data[0].is_empty(): return for _, rule in self.RULES.items(): @@ -40,10 +59,4 @@ class FusionOPAnalyzer(BaseAnalyzer): else: checker.find_fusion_matched_issues_with_times(graph_data, profiling_data) checker.make_record(self.result) - checker.make_render(self.html_render) - - def make_record(self): - pass - - def make_render(self): - pass + self.html = checker.make_render(self.html_render) \ No newline at end of file diff --git a/profiler/advisor/analyzer/graph_fusion/graph_fusion_checker.py b/profiler/advisor/analyzer/graph_fusion/graph_fusion_checker.py index e64020fdfe2ace37172e82ed562db1b66971d3d6..2cfde931a6116db41f1ed3bec2f17f64cd88ddeb 100644 --- a/profiler/advisor/analyzer/graph_fusion/graph_fusion_checker.py +++ b/profiler/advisor/analyzer/graph_fusion/graph_fusion_checker.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from typing import List @@ -27,6 +41,23 @@ class GraphFusionRules: graph.build() yield graph + @staticmethod + def get_attr_shape(node, type_name: str, attr_name: str) -> str: + attr_shape = [] + node_attrs = getattr(node, type_name, []) + for attrs in node_attrs: + attr = getattr(attrs, attr_name, []) + attr_shape.append(",".join(attr)) + return ";".join(attr_shape) + + @staticmethod + def get_attr_type(node, type_name: str, attr_name: str) -> str: + attr_type = [] + node_attrs = getattr(node, type_name, []) + for attr in node_attrs: + attr_type.append(getattr(attr, attr_name, "")) + return ";".join(attr_type) + def find_fusion_matched_issues(self, graphs: List[GraphDataset]): query_graphs = QueryGraphParser(self.fusion_rules) with tqdm(total=query_graphs.num_rules, leave=False, ncols=100, unit=" rules") as pbar: @@ -106,8 +137,11 @@ class GraphFusionRules: has_time_info = False if self.task_duration_list: has_time_info = True - candidate_dict['total_duration'] = round(sum(sum(duration) for duration in - self.task_duration_list[case_id]), 2) + candidate_dict['total_duration'] = round( + sum( + sum(duration) + for duration in self.task_duration_list[case_id] + ), 2) for node_index, refer_node in enumerate(nodes): match = [] index = 0 @@ -149,7 +183,7 @@ class GraphFusionRules: optimization_item = OptimizeItem( "fusion issue", f"Found {len(self.candidates)} fusion issues", - ["Check fusion issues detail in att_advisor*.html"] + ["Check fusion issues detail in mstt_advisor*.html"] ) total_time = 0.0 for candidate in self.task_duration_list: @@ -188,20 +222,3 @@ class GraphFusionRules: self.get_attr_type(host_node, "output", "dtype"), ] result.add_detail('fusion issues', detail=detail) - - @staticmethod - def get_attr_shape(node, type_name: str, attr_name: str) -> str: - attr_shape = [] - node_attrs = getattr(node, type_name, []) - for attrs in node_attrs: - attr = getattr(attrs, attr_name, []) - attr_shape.append(",".join(attr)) - return ";".join(attr_shape) - - @staticmethod - def get_attr_type(node, type_name: str, attr_name: str) -> str: - attr_type = [] - node_attrs = getattr(node, type_name, []) - for attr in node_attrs: - attr_type.append(getattr(attr, attr_name, "")) - return ";".join(attr_type) diff --git a/profiler/advisor/analyzer/memory/__init__.py b/profiler/advisor/analyzer/memory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/analyzer/memory/memory_analyzer.py b/profiler/advisor/analyzer/memory/memory_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..939e2de90c634ee6cca584dca345111dce26bb7b --- /dev/null +++ b/profiler/advisor/analyzer/memory/memory_analyzer.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.analyzer.memory.memory_checker import MemoryOpsChecker +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor + +logger = logging.getLogger() + + +class MemoryAnalyzer(BaseAnalyzer): + dataset_cls_list = [ScheduleAnalysisDataset] + + def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: + super().__init__(collection_path, n_processes, **kwargs) + key = ScheduleAnalysisDataset.get_key() + self.dataset = self.get_first_data_by_key(self.dataset_list, key) + self.result = OptimizeResult() + self.html_render = HTMLRender() + + @BaseAnalyzer.check_data((ScheduleAnalysisDataset.get_key(),)) + def optimize(self, **kwargs): + memory_checker = MemoryOpsChecker() + memory_checker.check_memory_ops(self.dataset) + memory_checker.make_record(self.result) + memory_checker.make_render(self.html_render, priority=self.get_priority(memory_checker.max_mem_op_dur), rank=kwargs.get("rank")) + return self.result + + def get_priority(self, max_mem_op_dur): + step_duration = getattr(self.dataset, "step_duration", None) + if step_duration is None: + return PriorityBackgroundColor.high + ratio = self.get_priority_by_time_ratio(max_mem_op_dur, step_duration) + + return ratio diff --git a/profiler/advisor/analyzer/memory/memory_checker.py b/profiler/advisor/analyzer/memory/memory_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..b446067ef7e6cd6ddfbd3c61f59bb49632c014c1 --- /dev/null +++ b/profiler/advisor/analyzer/memory/memory_checker.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +import logging +import yaml + +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset, MemCollector +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.cluster_analyse.common_func.file_manager import FileManager + +logger = logging.getLogger() + + +class MemoryOpsChecker: + + def __init__(self): + + self.memory_issues = False + self.optimization_item = [] + self.desc = "" + self.suggestions = [] + self.memory_ops_duration_threshold = None + self.max_mem_op_dur = 0 + + def check_memory_ops(self, event_dataset: ScheduleAnalysisDataset): + """ + :Param event_dataset: dataset of timeline event + """ + if not hasattr(event_dataset, "memory_ops") or not getattr(event_dataset, "memory_ops") or \ + not event_dataset.memory_ops.mem_op_info: + logger.debug("Skip slow memory ops checker, because no memory ops: %s", MemCollector.MEMORY_OP_NAME) + return + + rule = event_dataset.memory_ops.rule + max_dur_thres = rule.get("max_total_duration") + raw_problem = rule.get("problem") + + for memory_op_name, memory_op_info in event_dataset.memory_ops.mem_op_info.items(): + op_dur = memory_op_info.get("total_dur") + op_count = memory_op_info.get("count") + if op_dur < max_dur_thres: + continue + if op_dur > self.max_mem_op_dur: + self.max_mem_op_dur = op_dur + + self.memory_issues = True + self.desc += raw_problem.format(memory_op_num=op_count, memory_op_name=memory_op_name, + memory_op_dur=op_dur) + " " + for solution in rule.get("solutions", []): + if memory_op_name not in solution: + continue + suggestions = solution.get(memory_op_name, {}).get("desc") + for suggestion in suggestions: + self.suggestions.append(f"For {memory_op_name}: {suggestion}") + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + if not self.memory_issues: + return + + self.optimization_item.append(OptimizeItem("Memory", self.desc, self.suggestions)) + for optimization in self.optimization_item: + result.add(OptimizeRecord(optimization)) + + def make_render(self, html_render, **kwargs): + if not self.memory_issues: + return + priority = kwargs.get("priority") + html_render.render_template(key="memory", + template_dir="templates", + template_name="memory.html", + desc=self.desc, + suggestions=self.suggestions, + priority_background_color=priority, + rank=kwargs.get("rank")) diff --git a/profiler/advisor/analyzer/overall/environment_variable_analyzer.py b/profiler/advisor/analyzer/overall/environment_variable_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..c4468c36d0eded6b36ae265e239d95e1fdf2dbbb --- /dev/null +++ b/profiler/advisor/analyzer/overall/environment_variable_analyzer.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.prof_common.path_manager import PathManager +from profiler.advisor.dataset.environment_variable_dataset import EnvironmentVariableDataset +from profiler.advisor.analyzer.overall.environment_variable_checker import EnvironmentVariabelChecker +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor + + +class EnvironmentVariabelAnalyzer(BaseAnalyzer): + dataset_cls_list = [EnvironmentVariableDataset] + + def __init__(self, collection_path: str, n_processes: int = 1, **kwargs): + super().__init__(collection_path, n_processes, **kwargs) + self.dataset = self.get_first_data_by_key(self.dataset_list, EnvironmentVariableDataset.get_key()) + + def optimize(self, **kwargs): + try: + PathManager.check_input_directory_path(self.collection_path) + except RuntimeError as e: + logging.error("Invalid path: %s", str(e)) + return self.result + self.collection_path = PathManager.get_realpath(self.collection_path) + checker = EnvironmentVariabelChecker() + checker.format_env_suggest(self.dataset) + checker.make_record(self.result) + checker.make_render(self.html_render) + return self.result + + def get_priority(self): + return PriorityBackgroundColor.high + + def make_record(self): + pass + + def make_render(self): + pass diff --git a/profiler/advisor/analyzer/overall/environment_variable_checker.py b/profiler/advisor/analyzer/overall/environment_variable_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..25058a790cc10c03b658309590666e90a29b450e --- /dev/null +++ b/profiler/advisor/analyzer/overall/environment_variable_checker.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem +from profiler.advisor.result.item import OptimizeRecord +from profiler.advisor.common.analyzer_scopes import SupportedScopes +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.utils.utils import convert_to_int + + +class EnvironmentVariabelChecker: + ENV_SUGGEST_CONDITION = { + "ASCEND_GLOBAL_LOG_LEVEL": lambda x: x != "" and convert_to_int(x) != 3, + "HCCL_RDMA_TC": lambda x: x != "", + "HCCL_RDMA_SL": lambda x: x != "", + "ACLNN_CACHE_LIMIT": lambda x: x == "" or convert_to_int(x) < 10000, + "HOST_CACHE_CAPACITY": lambda x: x == "" or convert_to_int(x) == 0, + "ASCEND_ENHANCE_ENABLE": lambda x: convert_to_int(x) == 0, + "PYTORCH_NPU_ALLOC_CONF": lambda x: isinstance(x, str) and "expandable_segments:True" not in x, + "ASCEND_LAUNCH_BLOCKING": lambda x: convert_to_int(x) != 1, + } + + HEADERS = ["Environment", "Value", "Description", "Suggestion"] + + def __init__(self): + self.environment_info = self.read_environment_info() + self.env_suggest_csv = [] + self.env_suggest_html = [] + + @staticmethod + def read_environment_info(): + environment_variable_info_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), + "rules", + "environment_variable_info.yaml" + ) + return FileManager.read_yaml_file(environment_variable_info_path) + + def format_env_suggest(self, data): + data = data.env_data.get('ENV_VARIABLES', {}) + for env, value in data.items(): + if not self.ENV_SUGGEST_CONDITION.get(env, lambda x: False)(value): + continue + desc = self.environment_info.get(env, {}).get("desc", "") + suggest = self.environment_info.get(env, {}).get("suggest", "") + self.env_suggest_csv += [ + [ + env, + value, + desc, + suggest, + ] + ] + self.env_suggest_html += [ + [ + env, + value, + desc.replace('\n', '
'), + self.environment_info.get(env, {}).get("suggest_html", suggest), + ] + ] + + def make_record(self, result: OptimizeResult): + if not self.env_suggest_csv: + return + desc = f"Describe and suggest the optimal environment variable settings" + suggestion = "Please set the optimal environment variable" + + optimization_item = OptimizeItem( + SupportedScopes.ENVIRONMENT_VARIABLE_ANALYSIS, + desc, + [suggestion] + ) + result.add(OptimizeRecord(optimization_item)) + result.add_detail(SupportedScopes.ENVIRONMENT_VARIABLE_ANALYSIS, headers=self.HEADERS) + for env_suggest in self.env_suggest_csv: + result.add_detail(SupportedScopes.ENVIRONMENT_VARIABLE_ANALYSIS, detail=env_suggest) + + def make_render(self, html_render: HTMLRender): + if not self.env_suggest_html: + return + html_render.render_template(key="overall", + template_dir="templates", + template_name="environment_variable.html", + result={ + "headers": self.HEADERS, + "data": self.env_suggest_html, + }) diff --git a/profiler/advisor/analyzer/overall/overall_analyzer.py b/profiler/advisor/analyzer/overall/overall_analyzer.py deleted file mode 100644 index 916a396b3d096dc788954cbc8e8ba9755cd15f4e..0000000000000000000000000000000000000000 --- a/profiler/advisor/analyzer/overall/overall_analyzer.py +++ /dev/null @@ -1,45 +0,0 @@ -import logging -from typing import Dict, List - -from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer -from profiler.advisor.display.html.render import HTMLRender -from profiler.advisor.result.result import OptimizeResult -from profiler.compare_tools.compare_backend.utils.constant import Constant -from profiler.compare_tools.compare_interface.comparison_interface import ComparisonInterface - -logger = logging.getLogger() - - -class OverallSummaryAnalyzer(BaseAnalyzer): - - def __init__(self, profiling_path, benchmark_profiling_path=None, **kwargs): - self.benchmark_profiling_path = benchmark_profiling_path or profiling_path - self.profiling_path = profiling_path - self.html_render = HTMLRender() - self.result = OptimizeResult() - - def optimize(self, **kwargs): - compare_result = ComparisonInterface(self.benchmark_profiling_path, self.profiling_path).compare( - Constant.OVERALL_COMPARE) - - headers = compare_result.get('Model Profiling Time Distribution').get("headers", []) - rows = compare_result.get('Model Profiling Time Distribution').get("rows", []) - - self.make_record() - self.make_render(headers=headers, rows=rows) - return compare_result - - def make_record(self): - pass - - def make_render(self, **kwargs): - headers = kwargs.get("headers") - rows = kwargs.get("rows") - - if not headers or not rows: - logger.info("Empty headers or rows, skip render overall analysis html") - self.html_render.render_template(key="overall", - template_dir="templates", - template_name="overall_analysis.html", - headers=kwargs.get("headers"), - rows=kwargs.get("rows")) diff --git a/profiler/advisor/analyzer/overall/overall_summary_analyzer.py b/profiler/advisor/analyzer/overall/overall_summary_analyzer.py index c74ae0510331fb9ba8a1794bd724710ba19cfabf..8a5982d3ce92f4401fd0537c08d0176b42a02471 100644 --- a/profiler/advisor/analyzer/overall/overall_summary_analyzer.py +++ b/profiler/advisor/analyzer/overall/overall_summary_analyzer.py @@ -12,28 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import copy - import logging -from typing import Dict, List +import os +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer from profiler.advisor.display.html.render import HTMLRender -from profiler.advisor.result.result import OptimizeResult from profiler.advisor.result.item import OptimizeItem, OptimizeRecord -from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult from profiler.compare_tools.compare_backend.utils.constant import Constant -from profiler.advisor.common import constant as const from profiler.compare_tools.compare_interface.comparison_interface import ComparisonInterface -from profiler.advisor.utils.utils import get_file_path_from_directory, load_parameter class OverallSummaryAnalyzer(BaseAnalyzer): - OVERALL_SUMMARY_ANALYZER = "overall_summary_analysis" + OVERALL_SUMMARY_ANALYZER = "overall summary" advice_map = { - "Computing Time": "if you want more detailed advice please go to att_advisor_*.html", - "Uncovered Communication Time": "if you want more detailed advice please go to att_advisor_*.html", - "Free Time": "if you want more detailed advice please go to att_advisor_*.html" + "Computing Time": "if you want more detailed advice please go to mstt_advisor_*.html", + "Uncovered Communication Time": "if you want more detailed advice please go to mstt_advisor_*.html", + "Free Time": "if you want more detailed advice please go to mstt_advisor_*.html" } time_name_map = { "Computing Time": "computing", @@ -47,45 +42,37 @@ class OverallSummaryAnalyzer(BaseAnalyzer): 'SDMA Time(Num)': 'SDMA Time' } performance_time_dict = { - "Computing Time": ['Cube Time(Num)', 'Vector Time(Num)', 'Flash Attention Time(Forward)(Num)', - 'Flash Attention Time(Backward)(Num)', 'Other Time'], - "Uncovered Communication Time(Wait Time)": [], - "Free Time": ['SDMA Time(Num)'] + "Computing Time": "computing_time_ms", + " -- Flash Attention": "fa_time_ms", + " -- Conv": "conv_time_ms", + " -- Matmul": "matmul_time_ms", + " -- Vector": "vector_time_ms", + " -- SDMA(Tensor Move)": "tensor_move_time_ms", + " -- Other Cube": "other_cube_time_ms", + "Uncovered Communication Time": "uncovered_communication_time_ms", + " -- Wait": "wait_time_ms", + " -- Transmit": "transmit_time_ms", + "Free Time": "free_time_ms", + " -- SDMA": "sdma_time_ms", + " -- Free": "free_ms", + "E2E Time": "e2e_time_ms" } def __init__(self, collection_path: str, n_processes: int = 1, **kwargs): profile_path = get_profile_path(collection_path) super().__init__(profile_path, n_processes, **kwargs) - self.base_collection_path = kwargs.get("base_collection_path", "") - self._has_base_collection = False + self.benchmark_profiling_path = kwargs.get("benchmark_profiling_path", "") + self._has_benchmark_profiling = False self._is_minimal_profiling = False self.cur_data = {} - self.cur_data_table = {} self.cur_bottleneck = {} + self._disaggregate_perf = {} + self._disaggregate_benchmark_perf = {} self.cur_advices = "" - self._headers = [] - self._base_data = [] - self._comparison_data = [] self.html_render = HTMLRender() self.result = OptimizeResult() self.bottleneck_str = "" - self.bottleneck_table = {} - - @staticmethod - def split_duration_and_num(time_value: str) -> tuple: - split_data = time_value.split("s") # time value example: 0.229s(1756) - duration, num = 0.0, None - if len(split_data) >= 2: - try: - num = int(split_data[1].strip("()")) - except ValueError: - pass - if len(split_data) >= 1: - try: - duration = float(split_data[0]) - except ValueError: - print(f"[WARNING] Invalid time value: {time_value}.") - return duration, num + self.over_summary_analysis = {} @staticmethod def calculate_ratio(dividend, divisor): @@ -93,131 +80,121 @@ class OverallSummaryAnalyzer(BaseAnalyzer): return float("inf") return dividend / divisor + @staticmethod + def get_time_category_dict(overall_dict: dict): + time_category_dict = { + "Computing Time": round(overall_dict.get('computing_time_ms', 0.0), 3), + "Uncovered Communication Time": round(overall_dict.get('uncovered_communication_time_ms', 0.0), 3), + "Free Time": round(overall_dict.get('free_time_ms', 0.0), 3) + } + return time_category_dict + def path_check(self): - if self.base_collection_path: - if os.path.exists(self.base_collection_path): - self._has_base_collection = True + if self.benchmark_profiling_path: + if os.path.exists(self.benchmark_profiling_path): + self._has_benchmark_profiling = True else: - print(f"[WARNING] Invalid path which not exists: {self.base_collection_path}.") + logging.warning("Invalid path which not exists: %s.", self.benchmark_profiling_path) return os.path.exists(self.collection_path) def process(self): - base_collection_path = self.base_collection_path if self._has_base_collection else self.collection_path - result_data = ComparisonInterface(base_collection_path, self.collection_path).compare(Constant.OVERALL_COMPARE) - for data in result_data.values(): - self._headers = data.get("headers", []) - rows = data.get("rows", []) - if len(rows) == 2: - self._base_data = rows[0] - self._comparison_data = rows[1] - if not self._headers or not self._comparison_data: + self._disaggregate_perf = ComparisonInterface(self.collection_path).disaggregate_perf(Constant.OVERALL_COMPARE) + if not self._disaggregate_perf: return - self._is_minimal_profiling = 'E2E Time(Not minimal profiling)' not in self._headers - if self._has_base_collection: - self.cur_data["comparison_result"] = result_data - time_category_dict = {} - for time_category, time_list in self.performance_time_dict.items(): - time_value = self.get_time_value(time_category, self._comparison_data) - if time_value == Constant.INVALID_VALUE: - continue - duration, _ = self.split_duration_and_num(time_value) - time_category = time_category.split("(")[0] - time_category_dict[time_category] = duration - self.get_sub_category_time(time_category, time_list, duration) - self.cur_data["overall_data"] = time_category_dict - - def get_time_value(self, header_name: str, data_list: list): - try: - data_index = self._headers.index(header_name) - except ValueError: - return Constant.INVALID_VALUE - try: - time_value = data_list[data_index] - except IndexError: - return Constant.INVALID_VALUE - return time_value - - def get_sub_category_time(self, category: str, time_list: list, total_duration: float): - sub_time_dict = {} - for time_name in time_list: - time_value = self.get_time_value(time_name, self._comparison_data) - if time_value == Constant.INVALID_VALUE: - continue - sub_time_dict.setdefault(f"{category} Subtype", []).append(self.time_name_map.get(time_name, "")) - duration, num = self.split_duration_and_num(time_value) - sub_time_dict.setdefault(f"Duration(s)", []).append(duration) - sub_time_dict.setdefault(f"Duration Ratio", []).append( - "{:.2%}".format(self.calculate_ratio(duration, total_duration))) - sub_time_dict.setdefault(f"Kernel Number", []).append(num) - self.cur_data[self.time_name_map.get(category)] = sub_time_dict + self._is_minimal_profiling = self._disaggregate_perf.get("minimal_profiling", False) + self.cur_data["overall_data"] = self.get_time_category_dict(self._disaggregate_perf.get('overall', {})) + if self._has_benchmark_profiling: + self._disaggregate_benchmark_perf = ComparisonInterface( + self.benchmark_profiling_path).disaggregate_perf(Constant.OVERALL_COMPARE) def identify_bottleneck(self): overall_data = self.cur_data.get("overall_data") if not overall_data: return e2e_time = '%.3f' % sum([data for data in overall_data.values()]) - overall_bottleneck = f"The Model E2E Time is {e2e_time}s.\n" + overall_bottleneck = f"The Model E2E Time is {e2e_time}ms.\n" comparison_bottleneck = "" for time_type, time_value in overall_data.items(): - # add subtype time bottleneck - self.cur_bottleneck[self.time_name_map.get(time_type)] = f"{time_type} is {time_value}s.\n" # add overall bottleneck - overall_bottleneck += f" -- {time_type} is {time_value}s\n" + overall_bottleneck += f" -- {time_type} is {time_value}ms\n" if time_type == "Free Time" and self._is_minimal_profiling and self.calculate_ratio(time_value, e2e_time) > 0.1: overall_bottleneck += "percentage of free time exceed the threshold 10%." - if not self._has_base_collection: + if not self._has_benchmark_profiling: continue # add comparison bottleneck - time_type_origin = "Uncovered Communication Time(Wait Time)" \ - if time_type == "Uncovered Communication Time" else time_type - base_duration, _ = self.split_duration_and_num(self.get_time_value(time_type_origin, self._base_data)) + base_duration = self.get_time_category_dict( + self._disaggregate_benchmark_perf.get('overall', {}) + ).get(time_type) if time_value > base_duration: ratio = "{:.2%}".format(self.calculate_ratio(time_value - base_duration, base_duration)) comparison_bottleneck += f"{time_type} exceeds the benchmark by {ratio}\n" self.cur_bottleneck["overall_data"] = overall_bottleneck if comparison_bottleneck: self.cur_bottleneck["comparison_result"] = comparison_bottleneck + def optimize(self, **kwargs): if self.path_check(): self.process() self.identify_bottleneck() self.format_bottleneck() - self.format_cur_data() + self.format_over_summary_analysis() self.make_record() self.make_render() return self.result def format_bottleneck(self): result = '' - headers = [] - data_list = [] - data = [] - for key, value in self.cur_bottleneck.items(): + for _, value in self.cur_bottleneck.items(): if not value: continue - result += f'{key}: {value} \n' - headers.append(key) - data.append(value) - data_list.append(data) + result += f'{value} \n' self.bottleneck_str = result - self.bottleneck_table["headers"] = headers - self.bottleneck_table["data"] = data_list - def format_cur_data(self): - if not self.cur_data: - return - for data_type, data in self.cur_data.items(): - if not data: - continue - if data_type not in list(self.time_name_map.values()): - data_list = list(data.values()) - else: - data_list = [','.join(map(str, value)) for value in data.values()] - headers = list(data.keys()) - data_table = {"headers": headers, "data": [data_list]} - self.cur_data_table[data_type] = copy.deepcopy(data_table) + def format_over_summary_analysis(self): + headers = ['Performance Index', 'Duration(ms)', 'Duration Ratio'] + performance_data = self.get_analysis_data(self._disaggregate_perf) + benchmark_data = self.get_analysis_data(self._disaggregate_benchmark_perf) + if self._has_benchmark_profiling: + headers.append('Diff Duration(ms)') + self.format_analysis_with_benchmark(performance_data, benchmark_data, headers) + else: + self.format_analysis_only(performance_data, headers) + + def get_analysis_data(self, data_dict: dict): + if not data_dict: + return {} + return { + **data_dict.get("overall"), + **data_dict.get("computing_time_disaggregate"), + **data_dict.get("communication_time_disaggregate"), + **data_dict.get("free_time_disaggregate"), + } + def format_analysis_only(self, performance_data: dict, headers: list): + res = [] + total_duration = performance_data.get('e2e_time_ms', 0.0) + for time_name, time_key in self.performance_time_dict.items(): + row = [time_name] + duration = performance_data.get(time_key, 0.0) + row.append("{:.3f}".format(duration)) + row.append("{:.2%}".format(self.calculate_ratio(duration, total_duration))) + res.append(row) + self.over_summary_analysis["headers"] = headers + self.over_summary_analysis["data"] = res + + def format_analysis_with_benchmark(self, performance_data: dict, benchmark_data: dict, headers: list): + res = [] + total_duration = performance_data.get('e2e_time_ms', 0.0) + for time_name, time_key in self.performance_time_dict.items(): + row = [time_name] + duration = performance_data.get(time_key, 0.0) + row.append("{:.3f}".format(duration)) + row.append("{:.2%}".format(self.calculate_ratio(duration, total_duration))) + row.append("{:.3f}".format(duration - benchmark_data.get(time_key, 0.0))) + res.append(row) + self.over_summary_analysis["headers"] = headers + self.over_summary_analysis["data"] = res def make_record(self): """ @@ -232,20 +209,23 @@ class OverallSummaryAnalyzer(BaseAnalyzer): ) self.result.add(OptimizeRecord(optimization_item)) - self.result.add_detail(const.BOTTLENECK, self.bottleneck_table["headers"], self.bottleneck_table["data"][0]) - for data_type, data_dict in self.cur_data_table.items(): - if data_dict: - self.result.add_detail(const.DATA + data_type, data_dict["headers"], data_dict["data"][0]) + self.result.add_detail( + OverallSummaryAnalyzer.OVERALL_SUMMARY_ANALYZER, + headers=self.over_summary_analysis["headers"] + ) + for data in self.over_summary_analysis["data"]: + self.result.add_detail(OverallSummaryAnalyzer.OVERALL_SUMMARY_ANALYZER, detail=data) def make_render(self): if not self.bottleneck_str and not self.cur_advices: return + # 将\n替换为html换行 + bottleneck_str = self.bottleneck_str.replace('\n', '
') result_for_html = { - "Description" : self.bottleneck_str, - "suggestion" : self.cur_advices, - "details" : [self.bottleneck_table] + "Description": bottleneck_str, + "suggestion": self.cur_advices, + "details": [self.over_summary_analysis] } - self.html_render.render_template(key="overall", title=OverallSummaryAnalyzer.OVERALL_SUMMARY_ANALYZER, template_dir="templates", @@ -254,9 +234,13 @@ class OverallSummaryAnalyzer(BaseAnalyzer): torch_version=self.torch_version, result=result_for_html) + def get_priority(self): + pass + + def get_profile_path(collection_path): - for root, dirs, files in os.walk(collection_path): + for root, _, files in os.walk(collection_path): for file in files: if file.startswith("profiler_info"): return root - return "" \ No newline at end of file + return "" diff --git a/profiler/advisor/analyzer/schedule/dispatch/timeline_op_dispatch_analyzer.py b/profiler/advisor/analyzer/schedule/dispatch/timeline_op_dispatch_analyzer.py index 0e62a3ff0c8eebc0cf7b5b89953b8a0842df9c9d..126fe30176cf6ca0f1d7d3557c360f95af7b20be 100644 --- a/profiler/advisor/analyzer/schedule/dispatch/timeline_op_dispatch_analyzer.py +++ b/profiler/advisor/analyzer/schedule/dispatch/timeline_op_dispatch_analyzer.py @@ -16,26 +16,26 @@ # limitations under the License. import logging - from profiler.advisor.common import constant as const from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer -from profiler.advisor.dataset.timeline_event_dataset import TimelineEventDataset +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset from profiler.advisor.result.item import OptimizeItem, OptimizeRecord from profiler.advisor.result.result import OptimizeResult from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor logger = logging.getLogger() class OpDispatchAnalyzer(BaseAnalyzer): - dataset_cls_list = [TimelineEventDataset] + dataset_cls_list = [ScheduleAnalysisDataset] """ operator dispatch optimizer """ def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: super().__init__(collection_path, n_processes, **kwargs) - key = TimelineEventDataset.get_key() + key = ScheduleAnalysisDataset.get_key() self.dataset = self.get_first_data_by_key(self.dataset_list, key) self.result = OptimizeResult() self.html_render = HTMLRender() @@ -51,24 +51,24 @@ class OpDispatchAnalyzer(BaseAnalyzer): """ self.get_op_compile_info(self.dataset) self.make_record(self.result) - self.make_render(self.html_render) + self.make_render(self.html_render, rank=kwargs.get('rank')) return self.result - def get_op_compile_info(self, event_dataset: TimelineEventDataset): - """ - :Param event_dataset: dataset of timeline event - """ - if hasattr(event_dataset, "ops_compile"): - self._op_compile = getattr(event_dataset, "ops_compile") - if not self._op_compile or self._op_compile.total_count < const.MAX_OP_COMPILE_NUM: - return + def get_op_compile_info(self, event_dataset: ScheduleAnalysisDataset): + """ + :Param event_dataset: dataset of timeline event + """ + if hasattr(event_dataset, "ops_compile"): + self._op_compile = getattr(event_dataset, "ops_compile") + if not self._op_compile or self._op_compile.total_count < const.MAX_OP_COMPILE_NUM: + return - self._issues_record.append(['operator dispatch', - const.OP_COMPILE_ID, - self._op_compile.total_count, - self._op_compile.total_time]) - else: - logger.debug("Skip operator compile checker, because no op_compile attr find.") + self._issues_record.append(['operator dispatch', + const.OP_COMPILE_ID, + self._op_compile.total_count, + self._op_compile.total_time]) + else: + logger.debug("Skip operator compile checker, because no op_compile attr find.") def make_record(self, result: OptimizeResult): """ @@ -77,8 +77,9 @@ class OpDispatchAnalyzer(BaseAnalyzer): if not self._op_compile or len(self._issues_record) <= 0: return desc = f"Found {self._op_compile.total_count} operator compile issues." - suggestion = (f"Please use `torch_npu.npu.set_compile_mode(jit_compile=False)` to disable jit compile " - f"in dynamic shape usage.") + suggestion = ("Please place the following code at the entrance of the python script to disable jit compile. " \ + "Code: `torch_npu.npu.set_compile_mode(jit_compile=False); " + "torch_npu.npu.config.allow_internal_format = False`") self.optimization_item.append(OptimizeItem("Operator dispatch", desc, [suggestion])) for optimization in self.optimization_item: result.add(OptimizeRecord(optimization)) @@ -87,7 +88,7 @@ class OpDispatchAnalyzer(BaseAnalyzer): for op_info in self._issues_record: result.add_detail('operator dispatch', detail=op_info) - def make_render(self, html_render): + def make_render(self, html_render, **kwargs): issues = [] optimizations = [] for optimization in self.optimization_item: @@ -97,11 +98,21 @@ class OpDispatchAnalyzer(BaseAnalyzer): )) for record in self._issues_record: issues.append(dict(issue=record[0], - op_name=record[1], - counts=record[2], - total_time=record[3])) + op_name=record[1], + counts=record[2], + total_time=record[3])) html_render.render_template(key="schedule", template_dir="templates", template_name="operator_dispatch.html", issues=issues, - optimizers=optimizations) + optimizers=optimizations, + priority_background_color=self.get_priority(), + rank=kwargs.get("rank")) + + def get_priority(self): + step_duration = getattr(self.dataset, "step_duration", None) + op_compile_total_dur = getattr(self._op_compile, "total_time", None) + if step_duration is None or op_compile_total_dur is None: + return PriorityBackgroundColor.low + + return self.get_priority_by_time_ratio(op_compile_total_dur, step_duration) diff --git a/profiler/advisor/analyzer/schedule/fusion_ops/fusion_ops_analyzer.py b/profiler/advisor/analyzer/schedule/fusion_ops/fusion_ops_analyzer.py index c1eb24b8e1e11ac167a7eb9333867167a57dd524..7407823106ec6039605e87539e86e66f737e20f4 100644 --- a/profiler/advisor/analyzer/schedule/fusion_ops/fusion_ops_analyzer.py +++ b/profiler/advisor/analyzer/schedule/fusion_ops/fusion_ops_analyzer.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os import multiprocessing import logging import re @@ -8,26 +23,37 @@ from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer from profiler.advisor.common import constant as const from profiler.advisor.common.analyzer_scopes import SupportedScopes from profiler.advisor.common.timeline.event import TimelineEvent -from profiler.advisor.dataset.timeline_event_dataset import TimelineEventDataset +from profiler.advisor.config.config import Config +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset from profiler.advisor.result.item import OptimizeItem, OptimizeRecord from profiler.advisor.utils.utils import format_timeline_result from profiler.advisor.common.timeline.fusion_ops_db import init_timeline_ops_db +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor logger = logging.getLogger() class TimelineFusionOpsAnalyzer(BaseAnalyzer): - dataset_cls_list = [TimelineEventDataset] + dataset_cls_list = [ScheduleAnalysisDataset] def __init__(self, collection_path, n_processes: int = 1, **kwargs): super().__init__(collection_path, n_processes, **kwargs) self._matched_op_index = {} if self.n_processes <= 1 else multiprocessing.Manager().dict() self.matched_op_stacks = {} self.empty_stacks = True - key = TimelineEventDataset.get_key() + key = ScheduleAnalysisDataset.get_key() self.timeline_event_dataset = self.get_first_data_by_key(self.dataset_list, key) + def get_priority(self): + return PriorityBackgroundColor.low + def optimize(self, **kwargs): + disable_affinity_api = os.getenv(const.DISABLE_AFFINITY_API) + if disable_affinity_api is not None and disable_affinity_api.lower() == "true": + logger.info( + "Skip affinity api analysis due to longer processing time due to env 'DISABLE_AFFINITY_API'") + return self.result + for mode in [const.ATEN.lower(), const.OPTIMIZER.lower()]: for op_combined, npu_apis in tqdm(getattr(init_timeline_ops_db(self.cann_version, self.torch_version), @@ -40,7 +66,7 @@ class TimelineFusionOpsAnalyzer(BaseAnalyzer): logger.info("Finish timeline analysis") self.make_record() - self.make_render() + self.make_render(rank=kwargs.get("rank")) return self.result def find_fusion_ops(self, event_dataset, ops: str, npu_api: str, mode: str): @@ -60,6 +86,74 @@ class TimelineFusionOpsAnalyzer(BaseAnalyzer): except Exception as e: logger.warning("Failed to find fusion operators with regex %s, reason is %s", ops, e) + def make_record(self): + """ + make record for what and how to optimize + """ + if not self.matched_op_stacks: + return + + desc = f"Found {len(format_timeline_result(self.matched_op_stacks))} apis to be replaced" \ + f" based on the runtime env cann-{self.cann_version} and torch-{self.torch_version}" + suggestion = "Please replace training api according to sub table 'Affinity training api'" + if self.empty_stacks: + desc += ", but with no stack" + suggestion = const.TIMELINE_EMPTY_STACKS_PROMPT.format( + timeline_profiling_doc_url=Config().timeline_with_stack_doc_url + ) + + sheet_name = "Affinity apis" + optimization_item = OptimizeItem( + sheet_name, + desc, + [suggestion] + ) + + self.result.add(OptimizeRecord(optimization_item)) + + record_title = ["Affinity API", "Code stacks", "Stack called counts"] + self.result.add_detail(sheet_name, headers=record_title) + + for api_name, stacks_info in format_timeline_result(self.matched_op_stacks).items(): + if not stacks_info: + detail = [api_name, "null", "null"] + self.result.add_detail(sheet_name, detail=detail) + else: + for stack in stacks_info: + detail = [api_name, *stack] + self.result.add_detail(sheet_name, detail=detail) + + def make_render(self, **kwargs): + rank = kwargs.get("rank") + format_result_for_html = format_timeline_result(dict(self.matched_op_stacks), dump_html=True) + + self.html_render.render_template(key="schedule", + template_dir="templates", + template_name="affinity_api.html", + cann_version=self.cann_version, + torch_version=self.torch_version, + empty_stacks=self.empty_stacks, + with_stack_doc_url=Config().timeline_with_stack_doc_url, + api_doc_url=Config().timeline_api_doc_url, + result=format_result_for_html, + priority_background_color=self.get_priority(), + rank=rank) + + def query_stack(self, event_dataset): + if all([len(matched_index) == 0 for matched_index in self._matched_op_index.values()]): + return + + op_stack_list = event_dataset.parse_data_with_generator(self._query_stack_by_matched_index) + for op_stack in op_stack_list: + for op_rule, stack in op_stack.items(): + if op_rule not in self.matched_op_stacks: + self.matched_op_stacks[op_rule] = {} + if stack == const.TIMELINE_FUSION_OPS_NO_STACK_FLAG: + continue + if stack not in self.matched_op_stacks[op_rule]: + self.matched_op_stacks[op_rule][stack] = 0 + self.matched_op_stacks[op_rule][stack] += 1 + def _match_ops(self, event_dataset, ops: str, npu_api: str, mode: str): """ match operator based on fusion operators rule(without regex), only strictly equals of op name list means matched @@ -76,8 +170,8 @@ class TimelineFusionOpsAnalyzer(BaseAnalyzer): for index, event in enumerate(getattr(event_dataset, mode)): if self._replace_op_name_prefix(event.name, mode) != op_list[0]: continue - tmp_dequeue_event_names = [self._replace_op_name_prefix(event.name, mode) for event in - getattr(event_dataset, mode)[index: index + len(op_list)]] + tmp_dequeue_event_names = [self._replace_op_name_prefix(event.name, mode) + for event in getattr(event_dataset, mode)[index: index + len(op_list)]] if tmp_dequeue_event_names != op_list: continue api_ops_matched = True @@ -97,12 +191,13 @@ class TimelineFusionOpsAnalyzer(BaseAnalyzer): """ matched_op_index = set() total_op_name = "".join([f"{const.OP_SEP}{self._replace_op_name_prefix(event.name, mode)}{const.OP_SEP}" - for event in - getattr(event_dataset, mode)]) + for event in getattr(event_dataset, mode)]) matched_pattern_index_tuple = [(x.start(0), x.end(0)) for x in re.finditer(op_rule_pattern, total_op_name)] # convert list of index tuple to a whole list: [(3, 25), ...] -> [3, 25, ...] - total_ops_split_points = [num for sublist in matched_pattern_index_tuple for num in sublist] + total_ops_split_points = [num + for sublist in matched_pattern_index_tuple + for num in sublist] api_ops_matched = len(total_ops_split_points) != 0 @@ -114,9 +209,9 @@ class TimelineFusionOpsAnalyzer(BaseAnalyzer): # convert total ops name like "-add-mul-xxx-div-" to small pieces like [["add", "mul"], [...], ["div"]] # by the regex index and then calculate the real index for matched fusion operators in event dataset - for l, r in zip(total_ops_split_points, total_ops_split_points[1:]): - matched_op_flag = True if (l, r) in matched_pattern_index_tuple else False - matched_ops_list = total_op_name[l: r].strip(const.OP_SEP).split(const.OP_SEP + const.OP_SEP) + for left, right in zip(total_ops_split_points, total_ops_split_points[1:]): + matched_op_flag = True if (left, right) in matched_pattern_index_tuple else False + matched_ops_list = total_op_name[left: right].strip(const.OP_SEP).split(const.OP_SEP + const.OP_SEP) op_index.append([matched_op_flag, len(matched_ops_list)]) for i, _ in enumerate(op_index): if i > 0: @@ -138,70 +233,6 @@ class TimelineFusionOpsAnalyzer(BaseAnalyzer): if api_ops_matched: self._matched_op_index[npu_api + f":{op_rule_pattern}"] = sorted(list(matched_op_index)) - def make_record(self): - """ - make record for what and how to optimize - """ - if not self.matched_op_stacks: - return - - desc = f"Found {len(format_timeline_result(self.matched_op_stacks))} apis to be replaced" \ - f" based on the runtime env cann-{self.cann_version} and torch-{self.torch_version}" - suggestion = "Please replace training api according to sub table 'Affinity training api'" - if self.empty_stacks: - desc += ", but with no stack" - suggestion = const.TIMELINE_EMPTY_STACKS_PROMPT.format( - timeline_profiling_doc_url=const.TIMELINE_WITH_STACK_DOC_URL - ) - - optimization_item = OptimizeItem( - SupportedScopes.TIMELINE_FUSION_OPS, - desc, - [suggestion] - ) - - self.result.add(OptimizeRecord(optimization_item)) - - record_title = ["Affinity API", "Code stacks", "Stack called counts"] - self.result.add_detail(SupportedScopes.TIMELINE_FUSION_OPS, headers=record_title) - - for api_name, stacks_info in format_timeline_result(self.matched_op_stacks).items(): - if not stacks_info: - detail = [api_name, "null", "null"] - self.result.add_detail(SupportedScopes.TIMELINE_FUSION_OPS, detail=detail) - else: - for stack in stacks_info: - detail = [api_name, *stack] - self.result.add_detail(SupportedScopes.TIMELINE_FUSION_OPS, detail=detail) - - def make_render(self): - format_result_for_html = format_timeline_result(dict(self.matched_op_stacks), dump_html=True) - - self.html_render.render_template(key="schedule", - template_dir="templates", - template_name="affinity_api.html", - cann_version=self.cann_version, - torch_version=self.torch_version, - empty_stacks=self.empty_stacks, - with_stack_doc_url=const.TIMELINE_WITH_STACK_DOC_URL, - api_doc_url=const.TIMELINE_API_DOC_URL, - result=format_result_for_html) - - def query_stack(self, event_dataset): - if all([len(matched_index) == 0 for matched_index in self._matched_op_index.values()]): - return - - op_stack_list = event_dataset.parse_data_with_generator(self._query_stack_by_matched_index) - for op_stack in op_stack_list: - for op_rule, stack in op_stack.items(): - if op_rule not in self.matched_op_stacks: - self.matched_op_stacks[op_rule] = {} - if stack == const.TIMELINE_FUSION_OPS_NO_STACK_FLAG: - continue - if stack not in self.matched_op_stacks[op_rule]: - self.matched_op_stacks[op_rule][stack] = 0 - self.matched_op_stacks[op_rule][stack] += 1 - def _query_stack_by_matched_index(self, index, event): stack_record = {} event = TimelineEvent(event) @@ -255,7 +286,7 @@ class TimelineFusionOpsAnalyzer(BaseAnalyzer): op_pattern_list = op_rule.split(const.OP_SEP) format_op_pattern = "" for op_pattern in op_pattern_list: - matched_res = re.search(r'\((.*?)\)', op_pattern) + matched_res = re.search(r'\((\w*?)\)', op_pattern) ops_index_range = (matched_res.start() + 1, matched_res.end() - 1) if matched_res else ( 0, len(op_pattern)) diff --git a/profiler/advisor/analyzer/schedule/fusion_ops/timeline_api_stack_checker.py b/profiler/advisor/analyzer/schedule/fusion_ops/timeline_api_stack_checker.py index f684a4892111f113f6c502a010c9e14ccd43768a..126584c3d733401ee7a3481b794f46bf93fc51aa 100644 --- a/profiler/advisor/analyzer/schedule/fusion_ops/timeline_api_stack_checker.py +++ b/profiler/advisor/analyzer/schedule/fusion_ops/timeline_api_stack_checker.py @@ -1,9 +1,23 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from typing import List from profiler.advisor.common import constant as const from profiler.advisor.common.timeline.event import TimelineEvent -from profiler.advisor.dataset.timeline_event_dataset import TimelineEventDataset +from profiler.advisor.dataset.timeline_event_dataset import ComputationAnalysisDataset from profiler.advisor.result.result import OptimizeResult from profiler.advisor.result.item import OptimizeItem, OptimizeRecord from profiler.advisor.utils.utils import get_analyze_processes, ParallelJob @@ -21,7 +35,25 @@ class OpStackFinder: self.task_type = None self.matched_index = set() - def get_api_stack_by_op(self, event_dataset: TimelineEventDataset, op_name: List[str] = None, task_type: str = None, + @staticmethod + def _query_index_by_torch_to_npu(event_dataset, torch_to_npu_event): + dst_op_event_key = torch_to_npu_event.ts + dst_op_event = event_dataset.ops_with_stack.get(dst_op_event_key) + + if not dst_op_event: + return const.TIMELINE_BACKWARD_NO_STACK_CODE + + return int(dst_op_event.get("dataset_index")) + + @staticmethod + def _query_index_by_acl_to_npu(acl_to_npu_event): + if acl_to_npu_event: + return const.TIMELINE_ACL_TO_NPU_NO_STACK_CODE + + return const.TIMELINE_BACKWARD_NO_STACK_CODE + + def get_api_stack_by_op(self, event_dataset: ComputationAnalysisDataset, op_name: List[str] = None, + task_type: str = None, disable_multiprocess=False): """ :Param event_dataset: dataset of timeline event @@ -82,7 +114,16 @@ class OpStackFinder: for op_info in self._stack_record: result.add_detail('operator stacks', detail=op_info) - def _get_api_stack_by_op(self, event_dataset: TimelineEventDataset, op_name: str, task_type: str): + def query_stack(self, event_dataset: ComputationAnalysisDataset): + + if not event_dataset.dataset_len: + return + _ = event_dataset.parse_data_with_generator(self._query_stack_by_matched_index) + + def get_stack_record(self): + return self._stack_record + + def _get_api_stack_by_op(self, event_dataset: ComputationAnalysisDataset, op_name: str, task_type: str): for _, src_op_event in event_dataset.ops_with_task_type.items(): op_task_type = src_op_event.get(const.TASK_TYPE) @@ -110,24 +151,12 @@ class OpStackFinder: task_id = src_op_event.task_id if not task_id: continue + self.matched_index.add(dst_op_index) if dst_op_index not in self._task_id_record: self._task_id_record[dst_op_index] = [] self._task_id_record[dst_op_index].append([task_id, op_name, task_type]) - def _query_index_by_torch_to_npu(self, event_dataset, torch_to_npu_event): - dst_op_event_key = torch_to_npu_event.ts - dst_op_event = event_dataset.ops_with_stack.get(dst_op_event_key) - - if not dst_op_event: - return const.TIMELINE_BACKWARD_NO_STACK_CODE - - return dst_op_event.get("dataset_index") - - def _query_index_by_acl_to_npu(self, acl_to_npu_event): - if acl_to_npu_event: - return const.TIMELINE_ACL_TO_NPU_NO_STACK_CODE - def _query_stacks_multiprocess(self, event_dataset, op_name_list, task_type): for op_name in op_name_list: @@ -148,6 +177,7 @@ class OpStackFinder: return None event = TimelineEvent(event) stack = event.args.get(const.CALL_STACKS) + stack = stack if stack else const.NO_STACK_REASON_MAP.get(const.TIMELINE_BACKWARD_NO_STACK_CODE) for matched_op_info in self._task_id_record.get(index, []): self._stack_record.append([*matched_op_info, stack]) @@ -156,8 +186,3 @@ class OpStackFinder: self._stack_record.append([*matched_op_info, const.NO_STACK_REASON_MAP.get(const.TIMELINE_ACL_TO_NPU_NO_STACK_CODE)]) return None - - def query_stack(self, event_dataset: TimelineEventDataset): - if not event_dataset.dataset_len: - return - _ = event_dataset.parse_data_with_generator(self._query_stack_by_matched_index) diff --git a/profiler/advisor/analyzer/schedule/gc/__init__.py b/profiler/advisor/analyzer/schedule/gc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/analyzer/schedule/gc/gc_analyzer.py b/profiler/advisor/analyzer/schedule/gc/gc_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..b59a8fc2e2a25428e34cb51917462f9f6162bc46 --- /dev/null +++ b/profiler/advisor/analyzer/schedule/gc/gc_analyzer.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.analyzer.schedule.gc.gc_checker import GcChecker +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor + +logger = logging.getLogger() + + +class GcAnalyzer(BaseAnalyzer): + dataset_cls_list = [ScheduleAnalysisDataset] + + def __init__(self, collection_path, **kwargs): + super().__init__(collection_path, **kwargs) + self.result = OptimizeResult() + self.html_render = HTMLRender() + key = ScheduleAnalysisDataset.get_key() + self.timeline_event_dataset = self.get_first_data_by_key(self.dataset_list, key) + + @BaseAnalyzer.check_data((ScheduleAnalysisDataset.get_key(),)) + def optimize(self, **kwargs): + gc_checker = GcChecker() + gc_checker.check_gc(self.timeline_event_dataset, rank=kwargs.get("rank"), stage=kwargs.get("stage")) + gc_checker.make_record(self.result) + gc_checker.make_render(self.html_render, priority=self.get_priority(), rank=kwargs.get("rank")) + return self.result + + def get_priority(self): + return PriorityBackgroundColor.medium diff --git a/profiler/advisor/analyzer/schedule/gc/gc_checker.py b/profiler/advisor/analyzer/schedule/gc/gc_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..be1a60536774e849e425d3a9f0b724001274132f --- /dev/null +++ b/profiler/advisor/analyzer/schedule/gc/gc_checker.py @@ -0,0 +1,185 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import math +import os + +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.advisor.utils.utils import convert_to_float, convert_to_int, safe_division +from profiler.advisor.common import constant as const +from profiler.cluster_analyse.common_func.file_manager import FileManager + +logger = logging.getLogger() + + +class GcChecker: + + def __init__(self): + self.stage = None + self.rank = None + self.optimization_item = [] + self.gc_issues = False + self.gc_problem_with_count = "" + self.gc_problem_with_free = "" + self.desc = "" + self.suggestions = [] + self.solutions = None + self.gc_threshold = 0 + self.gc_topk_num = 0 + self.abnormal_gc_count = 0 + self.abnormal_gc_duration = 0 + self.abnormal_gc_list = [] + self.headers = ["timestamp", "duration(us)"] + self._init_rule() + + def check_gc(self, event_dataset: ScheduleAnalysisDataset, rank=None, stage=None): + """ + :Param event_dataset: dataset of timeline event + """ + self.rank = rank + self.stage = stage + + # 当用户cann和pta版本不支持采集gc信息时,通过timeline中的free和cann层acl事件 综合判断是否可能存在free + if not event_dataset.gc_events: + acl_events = getattr(event_dataset, "acl_events", []) + large_free_events = getattr(event_dataset, "large_free_events", []) + # 如果acl_events为空,则没有采集cann信息,不基于free+acl events进行gc分析 + if acl_events and large_free_events: + free_event = self.get_free_events_include_gc(large_free_events, acl_events) + if not free_event: + return + self.desc = self.gc_problem_with_free.format(free_duration_time=free_event.dur) + + return + + for gc_event in event_dataset.gc_events: + if convert_to_float(gc_event.dur) >= self.gc_threshold: + self.gc_issues = True + self.abnormal_gc_count += 1 + self.abnormal_gc_duration += convert_to_float(gc_event.dur) + self.abnormal_gc_list.append([gc_event.ts, gc_event.dur]) + self.abnormal_gc_duration = round(self.abnormal_gc_duration / 1000, 4) + self.abnormal_gc_list.sort(key=lambda x: x[1], reverse=True) + self.desc = self.gc_problem_with_count.format(gc_count=self.abnormal_gc_count, + gc_total_time=self.abnormal_gc_duration) + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + if not self.gc_issues: + return + + self.optimization_item.append(OptimizeItem("GC", self.desc, self.suggestions)) + for optimization in self.optimization_item: + result.add(OptimizeRecord(optimization)) + if self.rank is not None: + self.headers = ["Rank id"] + self.headers + sub_table_name = "GcAnalysis" if not self.stage else f"Stage-{self.stage}: GcAnalysis" + result.add_detail(sub_table_name, headers=self.headers) + + for row in self.abnormal_gc_list: + if self.rank is not None: + row = [self.rank] + row + result.add_detail(sub_table_name, detail=row) + + def make_render(self, html_render, **kwargs): + if not self.gc_issues: + return + priority = kwargs.get("priority") + rank = kwargs.get("rank") + show_num = min(self.gc_topk_num, self.abnormal_gc_count) + html_render.render_template(key="schedule", + template_dir="templates", + template_name="gc.html", + desc=self.desc, + solutions=self.solutions, + headers=self.headers, + datas=self.abnormal_gc_list[:show_num], + num=show_num, + priority_background_color=priority, + rank=rank) + + def get_free_events_include_gc(self, large_free_events, acl_events): + free_event_index, acl_event_index = 0, 0 + free_include_acl_events = {} + + while free_event_index < len(large_free_events) and acl_event_index < len(acl_events): + free_event = large_free_events[free_event_index] + free_event_name = f"{const.FREE}-{free_event_index}" + free_event_start_time = convert_to_float(free_event.ts) + free_event_end_time = free_event_start_time + convert_to_float(free_event.dur) + if free_event_name not in free_include_acl_events: + free_include_acl_events[free_event_name] = {} + + while acl_event_index < len(acl_events): + acl_event = acl_events[acl_event_index] + acl_event_start_time = convert_to_float(acl_event.ts) + acl_event_end_time = acl_event_start_time + convert_to_float(acl_event.dur) + + if acl_event_end_time < free_event_start_time: + acl_event_index += 1 + continue + if acl_event_start_time > free_event_end_time: + break + + if "acl_event_count" not in free_include_acl_events[free_event_name]: + free_include_acl_events[free_event_name]["acl_event_count"] = 0 + free_include_acl_events[free_event_name]["acl_event_count"] += 1 + + if "acl_event_dur" not in free_include_acl_events[free_event_name]: + free_include_acl_events[free_event_name]["acl_event_dur"] = 0.0 + free_include_acl_events[free_event_name]["acl_event_dur"] += convert_to_float(acl_event.dur) + + acl_event_index += 1 + + free_event_index += 1 + + # 按free持续时间降序排列,优先判断持续时间最长的free + event_indexs = range(len(large_free_events)) + for index, free_event in sorted(zip(event_indexs, large_free_events), key=lambda x: x[1].dur, reverse=True): + + free_event_name = f"{const.FREE}-{index}" + free_duration = convert_to_float(free_event.dur) + acl_event_dur = free_include_acl_events.get(free_event_name, {}).get("acl_event_dur", 0.0) + acl_event_count = free_include_acl_events.get(free_event_name, {}).get("acl_event_count", 0) + if safe_division(acl_event_dur, free_duration) < self.max_acl_event_time_ratio and safe_division( + acl_event_count, free_duration) < self.max_acl_event_num_ratio: + self.gc_issues = True + return free_event + return {} + + def _init_rule(self): + gc_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))), + "rules", + "gc.yaml" + ) + + gc_rule = FileManager.read_yaml_file(gc_rule_path) + + self.gc_threshold = convert_to_float(gc_rule.get("gc_threshold", 0)) + self.gc_topk_num = convert_to_int(gc_rule.get("top_num", 0)) + self.gc_problem_with_count = gc_rule.get("gc_problem_with_count", "") + self.gc_problem_with_free = gc_rule.get("gc_problem_with_free", "") + self.max_acl_event_num_ratio = convert_to_float(gc_rule.get("max_acl_event_num_ratio")) + self.max_acl_event_time_ratio = convert_to_float(gc_rule.get("max_acl_event_time_ratio")) + + self.solutions = gc_rule.get("solutions", []) + for solution in self.solutions: + for key, val in solution.items(): + self.suggestions.append(f"{key}, {val.get('desc')}") diff --git a/profiler/advisor/analyzer/schedule/syncbn/__init__.py b/profiler/advisor/analyzer/schedule/syncbn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/analyzer/schedule/syncbn/syncbn_analyzer.py b/profiler/advisor/analyzer/schedule/syncbn/syncbn_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..720b07e9429fb934590b40e4ec257bfd563a657b --- /dev/null +++ b/profiler/advisor/analyzer/schedule/syncbn/syncbn_analyzer.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.analyzer.schedule.syncbn.syncbn_checker import SyncBNChecker +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset + +logger = logging.getLogger() + + +class SyncBNAnalyzer(BaseAnalyzer): + dataset_cls_list = [ScheduleAnalysisDataset] + + def __init__(self, collection_path, **kwargs): + super().__init__(collection_path, **kwargs) + self.result = OptimizeResult() + self.html_render = HTMLRender() + key = ScheduleAnalysisDataset.get_key() + self.timeline_event_dataset = self.get_first_data_by_key(self.dataset_list, key) + + @BaseAnalyzer.check_data((ScheduleAnalysisDataset.get_key(),)) + def optimize(self, **kwargs): + syncbn_checker = SyncBNChecker() + syncbn_checker.check_syncbn(self.timeline_event_dataset) + syncbn_checker.make_record(self.result) + syncbn_checker.make_render(self.html_render, priority=self.get_priority(), rank=kwargs.get("rank")) + return self.result + + def get_priority(self): + return PriorityBackgroundColor.high \ No newline at end of file diff --git a/profiler/advisor/analyzer/schedule/syncbn/syncbn_checker.py b/profiler/advisor/analyzer/schedule/syncbn/syncbn_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..4a288a0008ccafbc460821f5a1f981ae99989d22 --- /dev/null +++ b/profiler/advisor/analyzer/schedule/syncbn/syncbn_checker.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os + +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.cluster_analyse.common_func.file_manager import FileManager + +logger = logging.getLogger() + + +class SyncBNChecker: + + def __init__(self): + self.optimization_item = [] + self.syncbn_issues = False + self.desc = "" + self.suggestions = [] + self.solutions = None + self.max_syncbn_num = None + self._init_rule() + + def check_syncbn(self, event_dataset: ScheduleAnalysisDataset): + """ + :Param event_dataset: dataset of timeline event + """ + if not hasattr(event_dataset, "sync_batchnorm") or not getattr(event_dataset, "sync_batchnorm"): + logger.debug("Skip syncbn checker, because no syncbn found") + return + + syncbn_num = len(event_dataset.sync_batchnorm) + self.syncbn_issues = syncbn_num >= self.max_syncbn_num + self.desc = self.desc.format(syncbn_num=syncbn_num) + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + if not self.syncbn_issues: + return + + self.optimization_item.append(OptimizeItem("SyncBatchNorm", self.desc, self.suggestions)) + for optimization in self.optimization_item: + result.add(OptimizeRecord(optimization)) + + def make_render(self, html_render, **kwargs): + if not self.syncbn_issues: + return + + priority = kwargs.get("priority") + rank = kwargs.get("rank") + html_render.render_template(key="schedule", + template_dir="templates", + template_name="sync_batchnorm.html", + desc=self.desc, + solutions=self.solutions, + priority_background_color=priority, + rank=rank) + + def _init_rule(self): + syncbn_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))), + "rules", + "sync_batchnorm.yaml" + ) + + syncbn_rule = FileManager.read_yaml_file(syncbn_rule_path) + + self.max_syncbn_num = syncbn_rule.get("max_syncbn_num") + self.desc = syncbn_rule.get("problem") + + self.solutions = syncbn_rule.get("solutions") + for solution in self.solutions: + for key, val in solution.items(): + self.suggestions.append(f"{key}, {val.get('desc')}") diff --git a/profiler/advisor/analyzer/schedule/synchronize_stream/__init__.py b/profiler/advisor/analyzer/schedule/synchronize_stream/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/analyzer/schedule/synchronize_stream/synchronize_stream_analyzer.py b/profiler/advisor/analyzer/schedule/synchronize_stream/synchronize_stream_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..45d7132e3deea27dcffbe84f06ecc5ba4654d9d9 --- /dev/null +++ b/profiler/advisor/analyzer/schedule/synchronize_stream/synchronize_stream_analyzer.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.analyzer.schedule.synchronize_stream.synchronize_stream_checker import SynchronizeStreamChecker +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset + +logger = logging.getLogger() + + +class SynchronizeStreamAnalyzer(BaseAnalyzer): + dataset_cls_list = [ScheduleAnalysisDataset] + + def __init__(self, collection_path, **kwargs): + super().__init__(collection_path, **kwargs) + self.result = OptimizeResult() + self.html_render = HTMLRender() + + key = ScheduleAnalysisDataset.get_key() + self.timeline_event_dataset = self.get_first_data_by_key(self.dataset_list, key) + + @BaseAnalyzer.check_data((ScheduleAnalysisDataset.get_key(),)) + def optimize(self, **kwargs): + synchronize_stream_checker = SynchronizeStreamChecker() + synchronize_stream_checker.check_synchronize(self.timeline_event_dataset) + synchronize_stream_checker.make_record(self.result) + synchronize_stream_checker.make_render(self.html_render, priority=self.get_priority(synchronize_stream_checker), + rank=kwargs.get("rank")) + return self.result + + def get_priority(self, synchronize_stream_checker): + return synchronize_stream_checker.priority diff --git a/profiler/advisor/analyzer/schedule/synchronize_stream/synchronize_stream_checker.py b/profiler/advisor/analyzer/schedule/synchronize_stream/synchronize_stream_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..5a793434a2d45b6ee81c916df17416cb7caaeb54 --- /dev/null +++ b/profiler/advisor/analyzer/schedule/synchronize_stream/synchronize_stream_checker.py @@ -0,0 +1,129 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os + +from profiler.advisor.analyzer.schedule.timeline_base_checker import TimelineBaseChecker +from profiler.advisor.common import constant as const +from profiler.advisor.config.config import Config +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset +from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.advisor.utils.utils import format_timeline_result, safe_division +from profiler.cluster_analyse.common_func.file_manager import FileManager + +logger = logging.getLogger() + + +class SynchronizeStreamChecker(TimelineBaseChecker): + + def __init__(self): + super().__init__(n_processes=1) + self.optimization_item = [] + self.synchronize_issues = False + self.desc = "" + self.suggestions = [] + self.solutions = [] + self.min_co_occurrence_ratio = 0 + self.priority = None + self._init_rule() + + def check_synchronize(self, event_dataset: ScheduleAnalysisDataset): + if not hasattr(event_dataset, "synchronize_stream") or not getattr(event_dataset, "synchronize_stream"): + logger.info("Skip synchronize stream checker, because no synchronize stream found") + return + + node_launch_num = 0 + co_occurrence_num = 0 + synchronize_num = 0 + synchronize_stream = event_dataset.synchronize_stream + for index, op in enumerate(synchronize_stream): + if op.name.startswith(const.NODE_LAUNCH): + node_launch_num += 1 + if op.name.startswith(const.SYNC_STREAM): + synchronize_num += 1 + + # 统计nodeLaunch 和 synchronizeStream 一前一后连续出现次数 + if index > 0 and synchronize_stream[index - 1].name.startswith(const.NODE_LAUNCH): + co_occurrence_num += 1 + + # 当共现次数很多时,则大概率设置了ASCEND_LAUNCH_BLOCKING环境变量 + co_occurrence_ratio = round(safe_division(co_occurrence_num, node_launch_num), 4) + if co_occurrence_ratio > self.min_co_occurrence_ratio: + self.synchronize_issues = True + + self.priority = self.get_priority() + + self.desc = self.desc.format(synchronize_num=synchronize_num, + node_launch_num=node_launch_num, + co_occur_ratio=co_occurrence_ratio) + + solutions = [] + for solution in solutions: + renderer_solution = {} + for key, val in solution.items(): + self.suggestions.append(f"{key}, {val.get('desc')}") + renderer_solution.update({key: val}) + self.solutions.append(renderer_solution) + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + if not self.synchronize_issues: + return + + self.optimization_item.append(OptimizeItem("SynchronizeStream", self.desc, self.suggestions)) + for optimization in self.optimization_item: + result.add(OptimizeRecord(optimization)) + + def make_render(self, html_render, **kwargs): + if not self.synchronize_issues: + return + priority = kwargs.get("priority") + rank = kwargs.get("rank") + format_result_for_html = format_timeline_result(dict(self.matched_op_stacks), dump_html=True) + html_render.render_template(key="schedule", + template_dir="templates", + template_name="synchronize_stream.html", + desc=self.desc, + solutions=self.solutions, + result=format_result_for_html, + with_stack_doc_url=Config().timeline_with_stack_doc_url, + empty_stacks=self.empty_stacks, + framework_black_list=self.framework_black_list, + priority_background_color=priority, + rank=rank) + + def get_priority(self): + return PriorityBackgroundColor.high + + def _init_rule(self): + synchronize_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))), + "rules", + "synchronize.yaml" + ) + + synchronize_rule = FileManager.read_yaml_file(synchronize_rule_path) + + self.min_co_occurrence_ratio = synchronize_rule.get("min_co_occurrence_ratio") + self.desc = synchronize_rule.get("problem") + + self.solutions = synchronize_rule.get("solutions") + for solution in self.solutions: + for key, val in solution.items(): + self.suggestions.append(f"{key}, {val.get('desc')}") diff --git a/profiler/advisor/analyzer/schedule/timeline_base_checker.py b/profiler/advisor/analyzer/schedule/timeline_base_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..9ef492c1a24a4efafcee86d30f2201adb40264a6 --- /dev/null +++ b/profiler/advisor/analyzer/schedule/timeline_base_checker.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +import multiprocessing +import logging + +from profiler.advisor.common import constant as const +from profiler.advisor.common.timeline.event import TimelineEvent +from profiler.advisor.dataset.timeline_event_dataset import ScheduleAnalysisDataset +from profiler.advisor.result.result import OptimizeResult + +logger = logging.getLogger() + + +class TimelineBaseChecker(ABC): + + def __init__(self, n_processes: int = 1): + self.n_processes = n_processes + self._matched_op_index = {} if self.n_processes <= 1 else multiprocessing.Manager().dict() + self.matched_op_stacks = {} + self.empty_stacks = True + self.framework_black_list = False + + def query_stack(self, event_dataset: ScheduleAnalysisDataset = None, profiling_with_stack: str = None): + if all([len(matched_index) == 0 for matched_index in self._matched_op_index.values()]): + return + + event_dataset = event_dataset if not profiling_with_stack else ScheduleAnalysisDataset( + collection_path=profiling_with_stack, data={}, _datasets={}, analysis_mode="fusion_ops", + build_dataset=False) + + op_stack_list = event_dataset.parse_data_with_generator(self._query_stack_by_matched_index) + for op_stack in op_stack_list: + for op, stack in op_stack.items(): + if op not in self.matched_op_stacks: + self.matched_op_stacks[op] = {} + if stack == const.TIMELINE_FUSION_OPS_NO_STACK_FLAG: + continue + if stack not in self.matched_op_stacks[op]: + self.matched_op_stacks[op][stack] = 0 + self.matched_op_stacks[op][stack] += 1 + + def _query_stack_by_matched_index(self, index, event): + stack_record = {} + event = TimelineEvent(event) + + matched_ops = [] + for op, matched_index in self._matched_op_index.items(): + if index not in matched_index: + continue + + matched_ops.append(op) + stack = event.args.get(const.CALL_STACKS) + + if not stack: + logger.debug("Got empty '%s' for event %s", const.CALL_STACKS, event) + continue + + if not self._is_keep_stack(stack): + self.framework_black_list = True + logger.debug("Drop stack from framework %s", const.FRAMEWORK_STACK_BLACK_LIST) + continue + + if self.empty_stacks and stack: + self.empty_stacks = False + + stack_record[op] = stack + + if matched_ops and not stack_record: + for op in matched_ops: + stack_record[op] = const.TIMELINE_FUSION_OPS_NO_STACK_FLAG + + return stack_record + + def _is_keep_stack(self, stack): + # 过滤掉torch, torch_npu, megatron, deepspeed等框架下的堆栈,这些源码基本是不能被修改的 + stack_list = stack.replace("\\r\\n", ";").split(";") + if not stack_list: + return False + + final_called_stack = stack_list[0] + for framework in const.FRAMEWORK_STACK_BLACK_LIST: + if framework in final_called_stack.split("/"): + return False + return True diff --git a/profiler/advisor/common/analyzer_scopes.py b/profiler/advisor/common/analyzer_scopes.py index 592f9d421e2bfad53a9ea621d951ae0166221623..a07a6d5de72c01c7ea568599de917d66a0a89f70 100644 --- a/profiler/advisor/common/analyzer_scopes.py +++ b/profiler/advisor/common/analyzer_scopes.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. class SupportedScopes: # used for specify fourth-level commands and define the key of the result dict @@ -6,9 +20,21 @@ class SupportedScopes: GRAPH = "graph" SLOW_RANK = "slow_rank" SLOW_LINK = "slow_link" + COMMUNICATION_RETRANSMISSION_DETECTION = "communication_retransmission_analysis" + PACKET = "packet_analysis" + BANDWIDTH_CONTENTION_DETECTION = "bandwidth_contention_analysis" OVER_ALL = "over_all" + ENVIRONMENT_VARIABLE_ANALYSIS = "environment_variable_analysis" DYNAMIC_SHAPE_ANALYSIS = "dynamic_shape_analysis" AICPU_ANALYSIS = "aicpu_analysis" BLOCK_DIM_ANALYSIS = "block_dim_analysis" OPERATOR_NO_BOUND_ANALYSIS = "operator_no_bound_analysis" TIMELINE_OP_DISPATCH = "timeline_op_dispatch" + DATALOADER = "dataloader" + SYNCBN = "syncbn" + SYNCHRONIZE_STREAM = "synchronize_stream" + FREQ_ANALYSIS = "freq_analysis" + MEMORY = "memory" + STAGE_COMPUTE = "stage_compute" + GC_ANALYSIS = "gc_analysis" + COMPARISON = "comparison" diff --git a/profiler/advisor/common/async_analysis_status.py b/profiler/advisor/common/async_analysis_status.py new file mode 100644 index 0000000000000000000000000000000000000000..4747d24520dd52c8673819c30f8e0aa913872083 --- /dev/null +++ b/profiler/advisor/common/async_analysis_status.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + + +class AsyncAnalysisStatus: + FAILED = "failed" + SUCCESS = "success" + ANALYZING = "analyzing" + + BAD_REQUEST_STATUS_CODE = 400 + NOT_FOUND_STATUS_CODE = 404 + INNER_ERROR_STATUS_CODE = 500 + NON_FAILED_STATUS_CODE = 200 diff --git a/profiler/advisor/common/constant.py b/profiler/advisor/common/constant.py index 697430ee6cabad8c055176a3368a8b4a25e977ab..dcaffee83df77ad192543a06acf40914a78a2c03 100644 --- a/profiler/advisor/common/constant.py +++ b/profiler/advisor/common/constant.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import stat + # timeline DEQUEUE = "Dequeue" DEQUEUE_SEP = "@" @@ -24,19 +27,23 @@ OPTIMIZER_SEP = "#" OPTIMIZER_STEP = "step" ENQUEUE = "enqueue" TORCH_TO_NPU = "torch_to_npu" +FREE = "free" OP_COMPILE_NAME = "AscendCL@aclopCompileAndExecute" OP_COMPILE_ID = "aclopCompileAndExecute" +SYNC_STREAM = "AscendCL@aclrtSynchronizeStream" +NODE_LAUNCH = "Node@launch" MAX_OP_COMPILE_NUM = 20 ACL_TO_NPU = "acl_to_npu" TASK_TYPE = "Task Type" CPU_OP = "cpu_op" AI_CORE = "AI_CORE" AI_CPU = "AI_CPU" +MIX_AIC = "MIX_AIC" CALL_STACKS = "Call stack" INPUT_DIMS = "Input Dims" OP_SEP = "-" -MA_ADVISOR_MAX_PROCESSES = 16 -MA_ADVISOR_ANALYZE_PROCESSES = "MA_ADVISOR_ANALYZE_PROCESSES" +ADVISOR_MAX_PROCESSES = 8 +ADVISOR_ANALYZE_PROCESSES = "ADVISOR_ANALYZE_PROCESSES" TIMELINE_OP_STACKS_DATASET = "timeline_op_stacks_dataset" TIMELINE_BACKWARD_NO_STACK = "Backward broadcast, without call stacks in profiling." TIMELINE_ACL_TO_NPU_NO_STACK = "Incoming flow is 'acl_to_npu', without call stacks in profiling." @@ -47,18 +54,7 @@ NO_STACK_REASON_MAP = { TIMELINE_BACKWARD_NO_STACK_CODE: "Backward broadcast, without call stacks in profiling.", TIMELINE_ACL_TO_NPU_NO_STACK_CODE: "Incoming flow is 'acl_to_npu', without call stacks in profiling." } -TIMELINE_API_DOC_URL = "https://gitee.com/ascend/mstt/blob/master/profiler/advisor/doc \ - /Samples%20of%20Fused%20Operator%20API%20Replacement.md" AFFINITY_TRAINING_API = "Affinity training api" -TIMELINE_WITH_STACK_DOC_URL = "https://www.hiascend.com/document/detail/zh/canncommercial/" \ - "70RC1/modeldevpt/ptmigr/AImpug_0067.html" -PyTorch_AOE_OPERATOR_TUNE_URL = "https://www.hiascend.com/document/detail/zh/canncommercial/" \ - "70RC1/devtools/auxiliarydevtool/aoe_16_045.html" -MSLite_Infer_AOE_OPEATOR_TUNE_URL = "https://www.mindspore.cn/lite/docs/en/master/use/cloud_infer/converter_tool_ascend.html#aoe-auto-tuning" -ENABLE_COMPILED_TUNE_URL = "https://www.hiascend.com/document/detail/zh/canncommercial/" \ - "70RC1/modeldevpt/ptmigr/AImpug_0059.html" - -ASCEND_PROFILER_URL = "https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/modeldevpt/ptmigr/AImpug_0067.html" TIMELINE_EMPTY_STACKS_PROMPT = "These APIs have no code stack. If parameter 'with_stack=False' while profiling, " \ "please refer to {timeline_profiling_doc_url} to set 'with_stack=True'. " \ "Otherwise, ignore following affinity APIs due to backward broadcast lack of stack." @@ -66,23 +62,12 @@ TIMELINE_EMPTY_STACKS_PROMPT = "These APIs have no code stack. If parameter 'wit CLUSTER_ANALYSIS = "Cluster analysis" SLOW_RANK_TIME_RATIO_THRESHOLD = 0.05 -# version_control -CANN_VERSION_C30 = '6.3.RC2' -CANN_VERSION_C13 = '7.0.RC1' -CANN_VERSION_C15 = '7.0.0' -CANN_VERSION_C17 = '8.0.RC1' -SUPPORTED_CANN_VERSION = [CANN_VERSION_C30, CANN_VERSION_C13, CANN_VERSION_C15, CANN_VERSION_C17] -DEFAULT_CANN_VERSION = CANN_VERSION_C17 -ASCEND_PYTORCH_PROFILER = "ascend_pytorch_profiler" -MSLITE = "mslite" -MSPROF = "msprof" -SUPPORTED_PROFILING_TYPE = [ASCEND_PYTORCH_PROFILER, MSLITE, MSPROF] -DEFAULT_PROFILING_TYPE = ASCEND_PYTORCH_PROFILER -TORCH_VERSION_1_11_0 = '1.11.0' -TORCH_VERSION_2_1_0 = '2.1.0' - -SUPPORTED_TORCH_VERSION = [TORCH_VERSION_1_11_0, TORCH_VERSION_2_1_0] -DEFAULT_TORCH_VERSION = TORCH_VERSION_2_1_0 +CANN_VERSION = "cann_version" +TORCH_VERSION = "torch_version" +PROFILING_TYPE = "profiling_type" +ANALYSIS_DIMENSIONS = "analysis_dimensions" + +PROFILER_METADATA = "profiler_metadata.json" TERMINAL_OUTPUT_HEADERS = ["No.", "Problem", "Description", "Suggestion"] SKIP_ANALYZE_PROMPT = "Finish analysis, no optimization suggestions" @@ -111,7 +96,7 @@ HTTP_PREFIXES = "http://" HTTPS_PREFIXES = "https://" COMMON_YAML_DIR = "modelarts/solution/ma_advisor_rules/" COMMON_ENDPOINT_SUFFIX = "obs.{}.myhuaweicloud.com" -INNER_ENDPOINT_SUFFIX= "obs.{}.ulanqab.huawei.com" +INNER_ENDPOINT_SUFFIX = "obs.{}.ulanqab.huawei.com" AICPU_RULES_YAML_NAME = "aicpu_rules.yaml" FUSION_PASS_YAML_NAME = "op_fusion_pass.yaml" @@ -120,6 +105,7 @@ CLOUD_YAML_NAME_LIST = [AICPU_RULES_YAML_NAME, FUSION_PASS_YAML_NAME, TIMELINE_F MAX_RETRIES = 3 TIMEOUT = 3 +DEPTH_LIMIT = 20 ADVISOR_RULE_PATH = "ADVISOR_RULE_PATH" CLOUD_RULE_PATH = "rules/cloud/" @@ -136,6 +122,33 @@ CLUSTER_ANALYSIS_OUTPUT = "cluster_analysis_output" KERNEL_DETAILS_CSV = "kernel_details.csv" CLUSTER_STEP_TIME_CSV = "cluster_step_trace_time.csv" CLUSTER_COMM_JSON = "cluster_communication.json" +COMMUNICATION_JSON = "communication.json" BOTTLENECK = "bottleneck" -DATA = "data" \ No newline at end of file +DATA = "data" +ADVISOR_ANALYSIS_OUTPUT_DIR = "advisor_analysis_result" +DEFAULT_PROCESSES = 8 +CLUSTER_ANALYSIS_FILE_PATTERN = [ + r'profiler_info_\d+\.json', "step_trace_time.csv", "communication.json", "communication_matrix.json" +] +ANALYSIS_OUTPUT_PATH = "ANALYSIS_OUTPUT_PATH" +DEFAULT_RANK_FOR_PROFILING_ANALYSIS = 0 +PROFILER_INFO_FILE_PATTERN = r"profiler_info_(\d+)\.json" +DISABLE_STREAMINIG_READER = "DISABLE_STREAMINIG_READER" +FRAMEWORK_STACK_BLACK_LIST = ["torch", "torch_npu", "megatron", "deepspeed"] +DISABLE_STREAMING_READER = "DISABLE_STREAMING_READER" +MAX_FILE_SIZE = 10 ** 10 +MAX_NUM_PROCESSES = 4 +DEFAULT_STEP = "-1" +STEP_RANK_SEP = "_" + +MAX_READ_LINE_BYTES = 8196 * 1024 +MAX_READ_FILE_BYTES = 64 * 1024 * 1024 * 1024 +MAX_READ_DB_FILE_BYTES = 8 * 1024 * 1024 * 1024 + +WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP +WRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC + +DISABLE_PROFILING_COMPARISON = "DISABLE_PROFILING_COMPARISON" +FREE_DURATION_FOR_GC_ANALYSIS = "FREE_DURATION_FOR_GC_ANALYSIS" +DISABLE_AFFINITY_API = "DISABLE_AFFINITY_API" diff --git a/profiler/advisor/common/enum_params_parser.py b/profiler/advisor/common/enum_params_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..49787ad6334859a93509f3b37b67a703a0579f49 --- /dev/null +++ b/profiler/advisor/common/enum_params_parser.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os +import logging +import typing + +from profiler.advisor.common.timeline.event import AdvisorDict +from profiler.advisor.utils.utils import singleton +from profiler.cluster_analyse.common_func.file_manager import FileManager + +logger = logging.getLogger() + + +@singleton +class EnumParamsParser(): + # 枚举变量抽象成yaml文件,统一管理,便于第三方服务对接advisor时调用当前类查询所有枚举变量参数的默认值和可选值 + + ARGUMENTS = "arguments" + ENVS = "envs" + OPTIONS = "options" + DEFAULT = "default" + TYPE = "type" + STR_TYPE = "str" + LIST_TYPE = "list" + INT_TYPE = "int" + BOOLEAN_TYPE = "boolean" + + def __init__(self): + enum_params_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "config", + "enum_parameters.yaml") + self.enum_params = FileManager.read_yaml_file(enum_params_path) + self._set_value() + + def get_keys(self): + return list(self.get_arguments_keys()) + list(self.get_envs_keys()) + + def get_arguments_keys(self): + return list(self.enum_params.get(self.ARGUMENTS, {}).keys()) + + def get_envs_keys(self): + return list(self.enum_params.get(self.ENVS, {}).keys()) + + def get_options(self, key, filter_func=None): + options = [] + for param_type in [self.ARGUMENTS, self.ENVS]: + if key not in self.enum_params.get(param_type, {}): + continue + options = self.enum_params.get(param_type, {}).get(key, {}).get(self.OPTIONS, []) + + if not options: + logger.error("Key %s not exists, optionals are %s", key, self.get_keys()) + + if filter_func is not None and callable(filter_func): + options = [value for value in options if filter_func(value)] + + return options + + def get_value_type(self, key): + for param_type in [self.ARGUMENTS, self.ENVS]: + if key not in self.enum_params.get(param_type, {}): + continue + value_type = self.enum_params.get(param_type, {}).get(key, {}).get(self.TYPE, self.STR_TYPE) + return value_type + return self.STR_TYPE + + def get_default(self, key): + default_value = None + for param_type in [self.ARGUMENTS, self.ENVS]: + if key not in self.enum_params.get(param_type, {}): + continue + default_value = self.enum_params.get(param_type, {}).get(key, {}).get(self.DEFAULT, []) + + if not default_value: + logger.error("Key %s not exists, optionals are %s", key, self.get_keys()) + + return default_value + + def _set_value(self): + + for key in self.get_keys(): + + if not hasattr(self, key): + setattr(self, str(key), AdvisorDict()) + + options = self.get_options(key) + + for value in options: + if not isinstance(value, typing.Hashable): + continue + getattr(self, key)[str(value)] = value diff --git a/profiler/advisor/common/graph/graph.py b/profiler/advisor/common/graph/graph.py index 6bab2042de3a09f9317f71fc6a5c9740743cc790..f86f5db7f2cec8f51d3daab6b9c6e9d22de44483 100644 --- a/profiler/advisor/common/graph/graph.py +++ b/profiler/advisor/common/graph/graph.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" import logging from typing import Dict, List, Tuple, Callable, Any, Optional, Union @@ -25,7 +41,7 @@ class Graph: self.edges = edges if edges is not None else list() def build(self): - for op_name, node in self.nodes.items(): + for _, node in self.nodes.items(): # add node and mark op_name as tag self.add_node(node, op_type=node.op_type @@ -72,7 +88,7 @@ class Graph: if pre_node is None or next_node is None: raise ValueError(f"Invalid edge from {pre_node} to {pre_node}.") - self.remove_edge(pre_node, next_node) + self.graph.remove_edge(pre_node, next_node) def get_subgraph(self, nodes: List[HostGraphNode]) -> nx.DiGraph: nodes = list(set(nodes)) @@ -86,50 +102,10 @@ class Graph: pass def get_node(self, node: HostGraphNode): - if node not in self.graph: - return - - return self.graph[node] + return self.graph[node] if node in self.graph else None def get_node_by_name(self, node_name: str): return self.nodes.get(node_name, None) def is_node_exists(self, node: HostGraphNode): return node in self.graph - - def draw(self, - graph: nx.DiGraph = None, - with_labels: bool = False, - labels: Dict[HostGraphNode, Any] = None, - pos_func: Callable = None, - font_weight: str = "bold", - savefig: bool = False, - node_size: int = 50, - **kwargs - ): - try: - import matplotlib.pylab as plt - except ImportError: - logger.error('Please install matplotlib first by using `pip install matplotlib`.') - return - - if graph is None: - graph = self.graph - - pos = pos_func(graph) if pos_func is not None else None - - if with_labels: - if labels is None: - labels = {k: f"{k}\n({v['op_name']})" for k, v in graph.nodes.items()} - - nx.draw(graph, - with_labels=with_labels, - pos=pos, - node_size=node_size, - font_weight=font_weight, - labels=labels, - **kwargs - ) - if savefig: - plt.savefig(self.name + ".png") - plt.show() diff --git a/profiler/advisor/common/graph/graph_match.py b/profiler/advisor/common/graph/graph_match.py index d0dfc162952b0c52bf9ed73cef2ff18ff5ffda24..fbf0a8abe8e049ccb6f9ff2baaa528e94cb3d7e2 100644 --- a/profiler/advisor/common/graph/graph_match.py +++ b/profiler/advisor/common/graph/graph_match.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" import itertools import logging from functools import lru_cache @@ -7,6 +23,47 @@ from typing import Dict, Generator, List, Callable, Hashable, Tuple import networkx as nx +class IsomorphismsIterArgsConfig: + def __init__(self, + query_graph: nx.Graph, + host_graph: nx.Graph, + *args, + directed: bool = None, + _node_attr_fun: Callable = None, + _node_struct_fun: Callable = None, + _edge_attr_fun: Callable = None, + **kwargs + ): + self.query_graph = query_graph + self.host_graph = host_graph + self.directed = directed + self.node_attr_fun = _node_attr_fun + self.node_struct_fun = _node_struct_fun + self.edge_attr_fun = _edge_attr_fun + self.args = args + self.kwargs = kwargs + + +class CandidateArgsConfig: + def __init__(self, + backbone: Dict, + query_graph: nx.Graph, + host_graph: nx.Graph, + next_node: Hashable = None, + directed: bool = True, + _node_attr_fun: Callable = None, + _node_struct_fun: Callable = None, + _edge_attr_fun: Callable = None): + self.backbone = backbone + self.query_graph = query_graph + self.host_graph = host_graph + self.next_node = next_node + self.directed = directed + self.node_attr_fun = _node_attr_fun + self.node_struct_fun = _node_struct_fun + self.edge_attr_fun = _edge_attr_fun + + @lru_cache() def match_node_attr_fun(query_node: Hashable, host_node: Hashable, @@ -119,7 +176,7 @@ def find_isomorphisms(query_graph: nx.Graph, ``` """ candidates = [] - for query_result in find_isomorphisms_iter( + for query_result in find_isomorphisms_iter(IsomorphismsIterArgsConfig( query_graph, host_graph, *args, @@ -127,32 +184,34 @@ def find_isomorphisms(query_graph: nx.Graph, _node_struct_fun=_node_struct_fun, _edge_attr_fun=_edge_attr_fun, **kwargs - ): + )): candidates.append(query_result) if limit and len(candidates) >= limit: return candidates return candidates -def find_isomorphisms_iter(query_graph: nx.Graph, - host_graph: nx.Graph, - directed: bool = None, - _node_attr_fun: Callable = None, - _node_struct_fun: Callable = None, - _edge_attr_fun: Callable = None, - ) -> Generator[Dict[Hashable, Hashable], None, None]: +def find_isomorphisms_iter(config: IsomorphismsIterArgsConfig) -> Generator[Dict[Hashable, Hashable], None, None]: """ A generation to find one isomorphic subgraph in host_graph for query_graph. - :param query_graph: The graph object to query - :param host_graph: The graph object to be queried - :param directed: Whether direction should be considered during search - :param _node_attr_fun: The function to match node attr - :param _node_struct_fun: The function to match node structural - :param _edge_attr_fun: The function to match edge attr - :return: Yield mappings from query node IDs to host graph IDs: {query_id: host_id, ...} + :param config: An instance of IsomorphismsIterArgsConfig containing the following attributes: + - query_graph: The graph object to query + - host_graph: The graph object to be queried + - directed: Whether direction should be considered during search + - node_attr_fun: The function to match node attr + - node_struct_fun: The function to match node structural + - edge_attr_fun: The function to match edge attr + :return: Yield mappings from query node IDs to host graph IDs: {query_id: host_id, ...} """ + query_graph: nx.Graph = config.query_graph + host_graph: nx.Graph = config.host_graph + directed: bool = config.directed + _node_attr_fun: Callable = config.node_attr_fun + _node_struct_fun: Callable = config.node_struct_fun + _edge_attr_fun: Callable = config.edge_attr_fun + if directed is None: # query graph and host graph should consider directions. if isinstance(query_graph, nx.DiGraph) and \ @@ -167,14 +226,15 @@ def find_isomorphisms_iter(query_graph: nx.Graph, while len(dq) > 0: backbone = dq.pop() - next_candidate_backbones = get_next_candidates(backbone=backbone, - query_graph=query_graph, - host_graph=host_graph, - directed=directed, - _node_attr_fun=_node_attr_fun, - _node_struct_fun=_node_struct_fun, - _edge_attr_fun=_edge_attr_fun, - ) + next_candidate_backbones = get_next_candidates(CandidateArgsConfig( + backbone=backbone, + query_graph=query_graph, + host_graph=host_graph, + directed=directed, + _node_attr_fun=_node_attr_fun, + _node_struct_fun=_node_struct_fun, + _edge_attr_fun=_edge_attr_fun, + )) for candidate in next_candidate_backbones: # find a legal isomorphism if len(candidate) == len(query_graph): @@ -184,23 +244,27 @@ def find_isomorphisms_iter(query_graph: nx.Graph, dq.appendleft(candidate) -def get_next_candidates( - backbone: Dict, - query_graph: nx.Graph, # noqa - host_graph: nx.Graph, # noqa - next_node: Hashable = None, - directed: bool = True, # noqa - _node_attr_fun: Callable = None, # noqa - _node_struct_fun: Callable = None, # noqa - _edge_attr_fun: Callable = None # noqa -) -> List[Dict[Hashable, Hashable]]: +def get_next_candidates(config: CandidateArgsConfig) -> List[Dict[Hashable, Hashable]]: """ Get a list of candidate node assignments for the next "step" of this map. - :param backbone: Mapping of query node IDs to one set of host graph IDs - :param next_node: Optional suggestion for the next node to assign + :param config: An instance of CandidateArgsConfig containing the following attributes: + - backbone: Dict, a mapping of query node IDs to one set of host graph node IDs. + - query_graph: nx.Graph, the query graph whose nodes are being mapped. + - host_graph: nx.Graph, the host graph where query nodes are being assigned. + - next_node: Hashable, an optional suggestion for the next node to assign from the query graph. + :return: List[Dict[Hashable, Hashable]]: A new list of node mappings with one additional element mapped """ + backbone: Dict = config.backbone + query_graph: nx.Graph = config.query_graph + host_graph: nx.Graph = config.host_graph + next_node: Hashable = config.next_node + directed: bool = config.directed + _node_attr_fun: Callable = config.node_attr_fun + _node_struct_fun: Callable = config.node_struct_fun + _edge_attr_fun: Callable = config.edge_attr_fun + node_priority = {n: 1 for n in query_graph.nodes} candidate_nodes = [] diff --git a/profiler/advisor/common/graph/graph_parser.py b/profiler/advisor/common/graph/graph_parser.py index d4c67fc1918af37a837e016bd9e5b813957b1aef..a89cf738fff8b679219380e71435148c7f8aa216 100644 --- a/profiler/advisor/common/graph/graph_parser.py +++ b/profiler/advisor/common/graph/graph_parser.py @@ -1,11 +1,29 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" import os import logging -import yaml import itertools from collections import deque from dataclasses import dataclass from typing import List, Tuple, Dict +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.utils.file import FileOpen + logger = logging.getLogger() @@ -89,7 +107,7 @@ class HostGraphParser: del self.graphs[0] @staticmethod - def _get_key_value( line): + def _get_key_value(line): res = line.split(':', 1) return res[0].strip(), res[1].strip().strip('"') @@ -219,9 +237,9 @@ class HostGraphParser: def _parse(self, graph_file): # pylint:disable=broad-except graph_list = [] - with open(graph_file, "r", encoding="gbk") as file: + with FileOpen(graph_file, "r") as file: try: - graph_list = self._parse_line(file, graph_list) + graph_list = self._parse_line(file.file_reader, graph_list) except Exception: logger.error( "Parse line %s of file %s failed, make sure the format is correct.", self.line_no, graph_file @@ -264,7 +282,11 @@ class HostGraphParser: if not self.graphs: self.nodes = {} return - self.nodes = {node.op_name: node for graph in self.graphs for node in graph.nodes.values()} + self.nodes = { + node.op_name: node + for graph in self.graphs + for node in graph.nodes.values() + } class QueryGraphNode: @@ -279,11 +301,12 @@ class QueryGraphNode: self._op_pass = op_pass QueryGraphNode._ID += 1 - def get_property(self, name): - """ - get property - """ - return getattr(self, name, lambda: None) + def __eq__(self, other): + return self._op_type == other._op_type and \ + self._id == other._id + + def __hash__(self): + return hash(self._op_type + str(self._id)) @property def op_type(self): @@ -301,13 +324,6 @@ class QueryGraphNode: def op_type(self, op_type): self._op_type = op_type - def __eq__(self, other): - return self._op_type == other._op_type and \ - self._id == other._id - - def __hash__(self): - return hash(self._op_type + str(self._id)) - @staticmethod def trim_string(string: str, length: int = -1): """ @@ -325,6 +341,12 @@ class QueryGraphNode: return string[:length] + def get_property(self, name): + """ + get property + """ + return getattr(self, name, lambda: None) + class QueryGraphParser: def __init__(self, rule_database_path: str): @@ -336,35 +358,6 @@ class QueryGraphParser: def fusion_rules(self): return self._fusion_rules - def load_database(self, rule_database): - if not os.path.isabs(rule_database): - rule_database = os.path.join(os.path.dirname(__file__), - "../", "../", - rule_database) - - if not os.path.exists(rule_database): - raise FileNotFoundError(f"Path {rule_database} does not exist.") - with open(rule_database, 'r') as f: - database = yaml.safe_load(f) - self.parse_yaml(database) - - def parse_yaml(self, yaml_database): - fusion_strategy_list = yaml_database.get("GraphFusion", []) - if yaml_database.get("UBFusion", []): - fusion_strategy_list.extend(yaml_database.get("UBFusion", [])) - for fusion_strategy in fusion_strategy_list: - if not isinstance(fusion_strategy, dict): - continue - (fusion_name, strategy), = fusion_strategy.items() - version = strategy.get("version", 0) - if version == 0 or version == "0": - self._fusion_rules[fusion_name] = self.build_query_graph_v0(fusion_name, - strategy.get('struct', [])) - elif version == 1 or version == "1": - self._fusion_rules[fusion_name] = self.build_query_graph_v1(fusion_name, - strategy.get('nodes', []), - strategy.get('edges', [])) - @staticmethod def build_query_graph_v0(graph_name: str, graph_struct: List[str]) -> List[Tuple]: nodes = dict() @@ -411,3 +404,32 @@ class QueryGraphParser: sub_graph = (sub_node, sub_edge, sub_graph_name,) graphs.append(sub_graph) return graphs + + def load_database(self, rule_database): + if not os.path.isabs(rule_database): + rule_database = os.path.join(os.path.dirname(__file__), + "../", "../", + rule_database) + + if not os.path.exists(rule_database): + raise FileNotFoundError(f"Path {rule_database} does not exist.") + + database = FileManager.read_yaml_file(rule_database) + self.parse_yaml(database) + + def parse_yaml(self, yaml_database): + fusion_strategy_list = yaml_database.get("GraphFusion", []) + if yaml_database.get("UBFusion", []): + fusion_strategy_list.extend(yaml_database.get("UBFusion", [])) + for fusion_strategy in fusion_strategy_list: + if not isinstance(fusion_strategy, dict): + continue + (fusion_name, strategy), = fusion_strategy.items() + version = strategy.get("version", 0) + if version == 0 or version == "0": + self._fusion_rules[fusion_name] = self.build_query_graph_v0(fusion_name, + strategy.get('struct', [])) + elif version == 1 or version == "1": + self._fusion_rules[fusion_name] = self.build_query_graph_v1(fusion_name, + strategy.get('nodes', []), + strategy.get('edges', [])) diff --git a/profiler/advisor/common/profiling/ge_info.py b/profiler/advisor/common/profiling/ge_info.py index 9996ec611a2a835bd8dffd24c3fbe7d8817ec29a..91642f967970fdf27f76754ee4bbd7f4ab4fcc50 100644 --- a/profiler/advisor/common/profiling/ge_info.py +++ b/profiler/advisor/common/profiling/ge_info.py @@ -1,14 +1,29 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- """ -DB +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ import logging import os from typing import Any, List from sqlalchemy import text +from sqlalchemy.exc import SQLAlchemyError from profiler.advisor.dataset.profiling.db_manager import ConnectionManager from profiler.advisor.dataset.profiling.profiling_parser import ProfilingParser +from profiler.advisor.utils.utils import check_path_valid logger = logging.getLogger() @@ -17,24 +32,30 @@ class GeInfo(ProfilingParser): """ ge info file """ - FILE_PATTERN = r"ge_info.db" FILE_PATTERN_MSG = "ge_info.db" FILE_INFO = "ge info" STATIC_OP_STATE = "0" DYNAMIC_OP_STATE = "1" + file_pattern_list = [r"ge_info.db"] + def __init__(self, path: str) -> None: super().__init__(path) self.op_state_info_list = None - def parse_from_file(self, profiling_db_file): + def parse_from_file(self, file: str): """ ge info """ - db_path, db_file = os.path.split(profiling_db_file) + db_path, db_file = os.path.split(file) + check_path_valid(db_path) if not ConnectionManager.check_db_exists(db_path, [db_file]): return False - conn = ConnectionManager(db_path, db_file) + try: + conn = ConnectionManager(db_path, db_file) + except SQLAlchemyError as e: + logger.error("Database error: %s", e) + return False if conn.check_table_exists(['TaskInfo']): with conn().connect() as sql_conn: self.op_state_info_list = sql_conn.execute(text("select op_name, op_state from TaskInfo")).fetchall() diff --git a/profiler/advisor/common/profiling/msprof.py b/profiler/advisor/common/profiling/msprof.py index 9453986b8225ccad68f2135d674e3832d987fcf0..150d3f985973f2c79ccf5406c114932aef5008fd 100644 --- a/profiler/advisor/common/profiling/msprof.py +++ b/profiler/advisor/common/profiling/msprof.py @@ -1,5 +1,18 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- """ -msprof +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ import logging from typing import Dict, List @@ -33,10 +46,11 @@ class Msprof(ProfilingParser): msprof """ - FILE_PATTERN = r"^msprof[_\d]+.json$" FILE_PATTERN_MSG = "msprof_*.json" FILE_INFO = "msprof" + file_pattern_list = [r"^msprof[_\d]+.json$"] + def __init__(self, path: str) -> None: super().__init__(path) self._tasks: List[TaskInfo] = [] @@ -49,6 +63,49 @@ class Msprof(ProfilingParser): self._data_process_time = 0.0 self._start_point = 0.0 + def __len__(self): + return len(self._tasks) + + @property + def step_time(self): + return self._iteration_time + self._data_process_time + + @property + def iteration_time(self): + return self._iteration_time + + @property + def iter_max_time(self): + return self._max_time + + @property + def iter_min_time(self): + return self._min_time + + @property + def data_process_time(self): + return self._data_process_time + + @property + def tasks(self): + return self._tasks + + @property + def model_id(self): + return self._model_id + + @property + def iteration_id(self): + return self._iteration_id + + @property + def process_pid(self): + return self._process_pid + + @property + def start_point(self): + return self._start_point + def parse_from_file(self, file: str): if not self._parse_json(file): return False @@ -99,46 +156,3 @@ class Msprof(ProfilingParser): self._model_id = int(task.name.split(":")[1]) elif "process_name" == task.name: self._process_pid[task.args.get("name")] = task.pid - - @property - def step_time(self): - return self._iteration_time + self._data_process_time - - @property - def iteration_time(self): - return self._iteration_time - - @property - def iter_max_time(self): - return self._max_time - - @property - def iter_min_time(self): - return self._min_time - - @property - def data_process_time(self): - return self._data_process_time - - @property - def tasks(self): - return self._tasks - - @property - def model_id(self): - return self._model_id - - @property - def iteration_id(self): - return self._iteration_id - - @property - def process_pid(self): - return self._process_pid - - def __len__(self): - return len(self._tasks) - - @property - def start_point(self): - return self._start_point diff --git a/profiler/advisor/common/profiling/op_summary.py b/profiler/advisor/common/profiling/op_summary.py index d79439dbad8e2c105bed737c1a1c3be1a2cecfc1..c042509df96c0c8feacb39a56e6f73358cd5d8a9 100644 --- a/profiler/advisor/common/profiling/op_summary.py +++ b/profiler/advisor/common/profiling/op_summary.py @@ -1,5 +1,18 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- """ -summary +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ import logging from decimal import Decimal @@ -16,13 +29,13 @@ class OpSummary(ProfilingParser): """ op summary """ - - FILE_PATTERN = r"^op_summary_[_\d]+\.csv$" FILE_PATTERN_MSG = "op_summary_*.csv" FILE_INFO = "op summary" STATIC_OP_STATE = "static" DYNAMIC_OP_STATE = "dynamic" + file_pattern_list = [r"^op_summary_[_\d]+\.csv$"] + def __init__(self, path: str) -> None: super().__init__(path) self.op_list: List[OpInfo] = [] @@ -52,7 +65,8 @@ class OpSummary(ProfilingParser): return True def get_static_shape_operators(self) -> List[Any]: - return [op_info.get_attr("op_name") for op_info in self.op_list if op_info.get_attr("op_state") == self.STATIC_OP_STATE] + return [op_info.get_attr("op_name") + for op_info in self.op_list if op_info.get_attr("op_state") == self.STATIC_OP_STATE] def get_total_task_duration(self): """ diff --git a/profiler/advisor/common/profiling/tasktime.py b/profiler/advisor/common/profiling/tasktime.py index 3ce09a783851e94163aa72f423788a373da5eb3a..211800585a6b3385e41d009827ec675bfa9df560 100644 --- a/profiler/advisor/common/profiling/tasktime.py +++ b/profiler/advisor/common/profiling/tasktime.py @@ -1,5 +1,18 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- """ -task time +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ import logging from typing import Dict, List @@ -17,11 +30,11 @@ class TaskTime(ProfilingParser): """ task time info """ - - FILE_PATTERN = r"^task_time_[_\d]+\.json$" FILE_PATTERN_MSG = "task_time*.json" FILE_INFO = "task time" + file_pattern_list = [r"^task_time_[_\d]+\.json$"] + def __init__(self, path: str) -> None: super().__init__(path) self._tasks: List[TaskInfo] = [] diff --git a/profiler/advisor/common/timeline/event.py b/profiler/advisor/common/timeline/event.py index 6001ac88722e5a77daba1c960e8ccfd6894889e6..79ee63211c33515ce8bad1a3a537caa65ac86511 100644 --- a/profiler/advisor/common/timeline/event.py +++ b/profiler/advisor/common/timeline/event.py @@ -1,3 +1,22 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +from decimal import Decimal + + class AdvisorDict(dict): def __getstate__(self): return self.__dict__ @@ -18,6 +37,14 @@ class AdvisorDict(dict): class TimelineEvent(AdvisorDict): def ts_include(self, event): + self_ts = self.ts + event_ts = event.ts + + if not self_ts or not event_ts: + return False + + self_dur = self.dur if not isinstance(self.dur, dict) else 0.0 + event_dur = event.dur if not isinstance(event.dur, dict) else 0.0 - return float(self.ts) <= float(event.ts) and float(self.ts) + float(self.dur) >= float(event.ts) + float( - event.dur) \ No newline at end of file + return Decimal(self_ts) <= Decimal(event_ts) and Decimal(self_ts) + Decimal(self_dur) >= Decimal( + event_ts) + Decimal(event_dur) diff --git a/profiler/advisor/common/timeline/fusion_ops_db.py b/profiler/advisor/common/timeline/fusion_ops_db.py index 8637befd1ab108928bdf8f4fdb19d9cab03ff960..ad8b5981c72b12c213146b205d1f1d86dd408589 100644 --- a/profiler/advisor/common/timeline/fusion_ops_db.py +++ b/profiler/advisor/common/timeline/fusion_ops_db.py @@ -1,13 +1,29 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" import logging import os -import yaml - from profiler.advisor.common import constant +from profiler.advisor.common.enum_params_parser import EnumParamsParser from profiler.advisor.common.timeline.fusion_ops_rule import OpRule from profiler.advisor.common.timeline.fusion_ops_rule_handler import TimelineOpRuleHandler from profiler.advisor.utils.log import get_log_level from profiler.advisor.utils.utils import get_file_path_by_walk +from profiler.cluster_analyse.common_func.file_manager import FileManager logger = logging.getLogger() logger.setLevel(get_log_level()) @@ -31,7 +47,8 @@ def get_timeline_fusion_ops_yaml_path(): logger.warning("The %s does not exist in path: %s. Try to use cloud or default local YAML file.", constant.TIMELINE_FUSION_OPS_YAML_NAME, os.path.normpath(advisor_rule_path)) # 检查云文件默认保存路径文件夹下是否存在相应文件, 默认路径 ~/rules/cloud/ - cloud_file_path = os.path.join(os.path.expanduser("~"), constant.CLOUD_RULE_PATH, constant.TIMELINE_FUSION_OPS_YAML_NAME) + cloud_file_path = os.path.join(os.path.expanduser("~"), constant.CLOUD_RULE_PATH, + constant.TIMELINE_FUSION_OPS_YAML_NAME) if os.path.exists(cloud_file_path): logger.debug("Successfully find The cloud %s file in %s.", constant.TIMELINE_FUSION_OPS_YAML_NAME, cloud_file_path) @@ -51,8 +68,8 @@ class FusionOperatorDB: def __init__(self, file_path=None, cann_version=None, torch_version=None): self.timeline_fusion_ops_yaml_path = os.path.normpath(get_timeline_fusion_ops_yaml_path()) - self.cann_version = cann_version or constant.DEFAULT_CANN_VERSION - self.torch_version = torch_version or constant.DEFAULT_TORCH_VERSION + self.cann_version = cann_version or EnumParamsParser().get_default(constant.CANN_VERSION) + self.torch_version = torch_version or EnumParamsParser().get_default(constant.TORCH_VERSION) self._supported_version_dict = {} @@ -241,8 +258,7 @@ class FusionOperatorDB: logger.debug("The rule yaml file is successfully found in path: %s", os.path.abspath(file_path)) - with open(file_path, "rb") as file: - db_content = yaml.safe_load(file) + db_content = FileManager.read_yaml_file(file_path) if not self._is_version_supported(db_content): self.is_empty = True diff --git a/profiler/advisor/common/version_control.py b/profiler/advisor/common/version_control.py index 38b054543fc61e90d91e8442a547376cff4c6406..ec30b3be9d84532ff4e8829341dd2da4d3dfc49f 100644 --- a/profiler/advisor/common/version_control.py +++ b/profiler/advisor/common/version_control.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" import logging from typing import List @@ -10,8 +26,8 @@ class VersionControl: @classmethod def is_supported(cls, cann_version: str) -> bool: """ - Check whether the CANN software version is supported, which can be viewed by executing the following command: - 'cat /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/ascend_toolkit_install.info' + Check whether the CANN software version is supported, which can be viewed by executing the following command: + 'cat /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/ascend_toolkit_install.info' """ flag = (cls._SUPPORT_VERSIONS.__contains__(cann_version)) if not flag: diff --git a/profiler/advisor/config/config.ini b/profiler/advisor/config/config.ini index c56c1dad9f0d7e9ac02ab76b0e79e102b010da12..08dd8f2d95af0b15d732450093d9acc170b237d7 100644 --- a/profiler/advisor/config/config.ini +++ b/profiler/advisor/config/config.ini @@ -9,8 +9,16 @@ tune_ops_file = operator_tuning_file.cfg [THRESHOLD] # operator_bound_ratio: (mte, cube, vector, scalar) ratio greater than this value will be checked in operator_bound_checker operator_bound_ratio = 0.8 +frequency_threshold = 0.05 [RULE-BUCKET] # region : URL of different regions where can download rule yaml file cn-north-9 = cnnorth9-modelarts-sdk cn-southwest-2 = cnsouthwest2-modelarts-sdk -cn-north-7 = cnnorth7-modelarts-sdk \ No newline at end of file +cn-north-7 = cnnorth7-modelarts-sdk +[URL] +timeline_api_doc_url = https://gitee.com/ascend/mstt/blob/master/profiler/advisor/doc/Samples%20of%20Fused%20Operator%20API%20Replacement.md +timeline_with_stack_doc_url = https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/modeldevpt/ptmigr/AImpug_0067.html +pytorch_aoe_operator_tune_url = https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/devtools/auxiliarydevtool/aoe_16_045.html +mslite_infer_aoe_operator_tune_url = https://www.mindspore.cn/lite/docs/en/master/use/cloud_infer/converter_tool_ascend.html#aoe-auto-tuning +enable_compiled_tune_url = https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/modeldevpt/ptmigr/AImpug_0059.html +ascend_profiler_url = https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/modeldevpt/ptmigr/AImpug_0067.html \ No newline at end of file diff --git a/profiler/advisor/config/config.py b/profiler/advisor/config/config.py index 12f4526f8c95a747f97272aed6cf8e4e822da676..2c074d6535b83b9d8e1a1a225dcfe06b6fbe9a1f 100644 --- a/profiler/advisor/config/config.py +++ b/profiler/advisor/config/config.py @@ -1,13 +1,13 @@ """ advisor config """ -from profiler.advisor.utils.utils import Timer import logging import os -from configparser import ConfigParser +from profiler.advisor.utils.utils import Timer from profiler.advisor.utils.utils import singleton +from profiler.prof_common.utils import SafeConfigReader logger = logging.getLogger() @@ -21,15 +21,27 @@ class Config: _CONFIG_DIR_NAME = "config" _CONFIG_FILE_NAME = "config.ini" + _REQUIRED_SECTIONS = { + 'LOG': ['console_logging_level'], + 'ANALYSE': ['analysis_result_file', 'tune_ops_file'], + 'THRESHOLD': ['operator_bound_ratio', 'frequency_threshold'], + 'RULE-BUCKET': ['cn-north-9', 'cn-southwest-2', 'cn-north-7'], + 'URL': [ + 'timeline_api_doc_url', 'timeline_with_stack_doc_url', + 'pytorch_aoe_operator_tune_url', 'mslite_infer_aoe_operator_tune_url', + 'enable_compiled_tune_url', 'ascend_profiler_url' + ] + } def __init__(self) -> None: - config = ConfigParser(allow_no_value=True) self._work_path = os.getcwd() # pwd self._root_path = os.path.abspath(os.path.join(__file__, "../../")) - config.read(os.path.join(self._root_path, self._CONFIG_DIR_NAME, self._CONFIG_FILE_NAME)) - self.config = config + self.config_reader = SafeConfigReader(os.path.join(self._root_path, self._CONFIG_DIR_NAME, + self._CONFIG_FILE_NAME)) + self.config_reader.validate(self._REQUIRED_SECTIONS) + self.config = self.config_reader.get_config() # ANALYSE - self._analysis_result_file = self._normalize_path(config.get("ANALYSE", "analysis_result_file")) + self._analysis_result_file = self._normalize_path(self.config.get("ANALYSE", "analysis_result_file")) self._tune_ops_file = os.path.abspath( os.path.join(self._work_path, f"operator_tuning_file_{Timer().strftime}.cfg")) self.log_path = None @@ -97,6 +109,55 @@ class Config: """ return float(self.config.get("THRESHOLD", "operator_bound_ratio")) + @property + def frequency_threshold(self) -> float: + """ + frequency_threshold + """ + return float(self.config.get("THRESHOLD", "frequency_threshold")) + + @property + def timeline_api_doc_url(self) -> str: + try: + return self.config.get("URL", "timeline_api_doc_url") + except Exception: + return "" + + @property + def timeline_with_stack_doc_url(self) -> str: + try: + return self.config.get("URL", "timeline_with_stack_doc_url") + except Exception: + return "" + + @property + def pytorch_aoe_operator_tune_url(self) -> str: + try: + return self.config.get("URL", "pytorch_aoe_operator_tune_url") + except Exception: + return "" + + @property + def mslite_infer_aoe_operator_tune_url(self) -> str: + try: + return self.config.get("URL", "mslite_infer_aoe_operator_tune_url") + except Exception: + return "" + + @property + def enable_compiled_tune_url(self) -> str: + try: + return self.config.get("URL", "enable_compiled_tune_url") + except Exception: + return "" + + @property + def ascend_profiler_url(self) -> str: + try: + return self.config.get("URL", "ascend_profiler_url") + except Exception: + return "" + def set_log_path(self, result_file: str, log_path: str = None): self.log_path = log_path if log_path is not None else os.path.join(self._work_path, "log") os.makedirs(self.log_path, exist_ok=True) @@ -106,3 +167,5 @@ class Config: def remove_log(self): if self.log_path and os.path.isdir(self.log_path) and not os.listdir(self.log_path): os.rmdir(self.log_path) + + diff --git a/profiler/advisor/config/enum_parameters.yaml b/profiler/advisor/config/enum_parameters.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ebea1740a2e32ee6cdd274e75576a69f15deee6d --- /dev/null +++ b/profiler/advisor/config/enum_parameters.yaml @@ -0,0 +1,50 @@ +arguments: + cann_version: + type: str + options: + - 6.3.RC2 + - 7.0.RC1 + - 7.0.0 + - 8.0.RC1 + default: 8.0.RC1 + + torch_version: + type: str + options: + - 1.11.0 + - 2.1.0 + default: 2.1.0 + + analysis_dimensions: + type: list + options: + - [ computation, communication, schedule, memory ] + - [ computation ] + - [ communication ] + - [ schedule ] + - [ memory ] + default: [ computation, communication, schedule, memory ] + + profiling_type: + type: str + options: + - ascend_pytorch_profiler + - mslite + - msprof + default: ascend_pytorch_profiler + +envs: + ADVISOR_ANALYZE_PROCESSES: + type: int + options: [ 1, 2, 3, 4, 5, 6, 7, 8 ] + default: 1 + + DISABLE_PROFILING_COMPARISON: + type: boolean + options: [ true, false ] + default: false + + DISABLE_AFFINITY_API: + type: boolean + options: [ true, false ] + default: false diff --git a/profiler/advisor/config/profiling_data_version_config.yaml b/profiler/advisor/config/profiling_data_version_config.yaml index 4ef76105a07c28c5072c4bbfe20fd39a938038b7..b8c92fe074d3bf67a23214d18f6a2438be130314 100644 --- a/profiler/advisor/config/profiling_data_version_config.yaml +++ b/profiler/advisor/config/profiling_data_version_config.yaml @@ -1,18 +1,19 @@ versions: - version: 8.0.RC1 dirs_pattern: + ASCEND_PROFILER_OUTPUT: [ op_summary ] ^PROF_\d{6}_\d{17}_\w+$: - mindstudio_profiler_output: - [ op_summary, msprof ] + mindstudio_profiler_output: [ op_summary, msprof ] class_attr: op_summary: OpSummary msprof: Msprof file_attr: - op_summary: ^op_summary_\d{14}\.csv$ msprof: ^msprof_\d{14}\.json$ + op_summary: [ kernel_details.csv, '^op_summary_\d{14}\.csv$' ] - version: 7.0.0 dirs_pattern: + ASCEND_PROFILER_OUTPUT: [ op_summary ] ^PROF_\d{6}_\d{17}_\w+$: ^device_\d+$: summary: @@ -28,13 +29,14 @@ versions: msprof: Msprof ge_info: GeInfo file_attr: - op_summary: ^op_summary_\d+_\d+_\d{14}\.csv$ + op_summary: [ kernel_details.csv, '^op_summary_\d+_\d+_\d{14}\.csv$'] task_time: ^task_time_\d+_\d+_\d{14}\.json$ msprof: ^msprof_\d+_\d+_\d{14}\.json$ ge_info: ge_info.db - version: 7.0.RC1 dirs_pattern: + ASCEND_PROFILER_OUTPUT: [ op_summary ] ^PROF_\d{6}_\d{17}_\w+$: ^device_\d+$: summary: @@ -50,13 +52,14 @@ versions: msprof: Msprof ge_info: GeInfo file_attr: - op_summary: ^op_summary_\d+_\d+_\d+_\d{14}\.csv$ + op_summary: [ kernel_details.csv, '^op_summary_\d+_\d+_\d+_\d{14}\.csv$'] task_time: ^task_time_\d+_\d+_\d+_\d{14}\.json$ msprof: ^msprof_\d+_\d+_\d+_\d{14}\.json$ ge_info: ge_info.db - version: 6.3.RC2 dirs_pattern: + ASCEND_PROFILER_OUTPUT: [ op_summary ] ^PROF_\d{6}_\d{17}_\w+$: ^device_\d+$: summary: @@ -72,9 +75,7 @@ versions: msprof: Msprof ge_info: GeInfo file_attr: - op_summary: ^op_summary_\d+_\d+\.csv$ + op_summary: [ kernel_details.csv, '^op_summary_\d+_\d+\.csv$'] task_time: ^task_time_\d+_\d+\.json$ msprof: ^msprof_\d+_\d+\.json$ ge_info: ge_info.db - - diff --git a/profiler/advisor/dataset/ai_core_freq/__init__.py b/profiler/advisor/dataset/ai_core_freq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/dataset/ai_core_freq/ai_core_freq_dataset.py b/profiler/advisor/dataset/ai_core_freq/ai_core_freq_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..db31c1a0c5f37467c8191fdb2dc419b925ee4bd5 --- /dev/null +++ b/profiler/advisor/dataset/ai_core_freq/ai_core_freq_dataset.py @@ -0,0 +1,163 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import math +import os +import traceback + +import ijson +from tqdm import tqdm + +from profiler.advisor.common import constant as const +from profiler.advisor.common.timeline.event import TimelineEvent +from profiler.advisor.utils.utils import get_file_path_from_directory +from profiler.advisor.utils.utils import convert_to_float, parse_json_with_generator +from profiler.advisor.dataset.profiling.device_info import DeviceInfoParser +from profiler.advisor.config.config import Config + +logger = logging.getLogger() + + +class AICoreFreqDataset: + + def __init__(self, collection_path, data: dict, build_dataset=True, **kwargs) -> None: + + self._profiler_step = [] + self._ai_core_ops = [] + self._ai_core_freq: [TimelineEvent] = [] + self._previous_freq_index = -1 + + self.timeline_dir = collection_path + self.timeline_data_list = get_file_path_from_directory(collection_path, + lambda file: file.endswith("trace_view.json")) + + self.step = kwargs.get("step") + self.op_freq = {} + info = DeviceInfoParser(collection_path) + info.parse_data() + if not Config().get_config("aic_frequency"): + return + if self.parse(): + key = self.get_key() + if key not in data: + data[key] = [] + data[key].append(self) + + @property + def profiler_step(self): + return self._profiler_step + + @property + def ai_core_freq(self): + return self._ai_core_freq + + @property + def ai_core_ops(self): + return self._ai_core_ops + + @classmethod + def get_key(cls): + """ + get key of dataset + :return: key + """ + return cls.__module__.rsplit('.', maxsplit=1)[-1] + + def parse(self): + + if len(self.timeline_data_list) == 0: + logger.warning("Please ensure trace_view.json in %s, skip timeline analysis.", self.timeline_dir) + return False + + if len(self.timeline_data_list) > 1: + logger.warning("Found multiple trace_view.json in %s, load the file of device 0 for analysis.", + self.timeline_dir) + + _ = parse_json_with_generator(sorted(self.timeline_data_list)[0], self._add_event) + + target_ai_core_ops = self._get_target_ai_core_ops() + self._get_op_frequency(target_ai_core_ops) + return True + + def _add_profiler_step(self, event): + if event.name.startswith("ProfilerStep"): + self._profiler_step.append(event) + + def _add_ai_core_ops(self, event): + if event.args.get("Task Type") in ["MIX_AIC", "AI_CORE"]: + self._ai_core_ops.append(event) + + def _add_ai_core_freq(self, event): + if event.name == "AI Core Freq": + if self._previous_freq_index != -1: + self._ai_core_freq[self._previous_freq_index]["end"] = event.get("ts", float(math.inf)) + self._previous_freq_index += 1 + event.setdefault("end", float(math.inf)) + self._ai_core_freq.append(event) + + def _add_event(self, index, event): + event["dataset_index"] = index + if not isinstance(event, TimelineEvent): + event = TimelineEvent(event) + + self._add_profiler_step(event) + self._add_ai_core_ops(event) + self._add_ai_core_freq(event) + + return True + + def _get_target_ai_core_ops(self): + target_ai_core_ops = [] + if not self.step or f"ProfilerStep#{self.step}" not in [event.name for event in self._profiler_step]: + target_ai_core_ops = self._ai_core_ops + else: + for step_event in self._profiler_step: + if step_event.name != f"ProfilerStep#{self.step}": + continue + + for ai_core_op_event in self._ai_core_ops: + if step_event.ts_include(ai_core_op_event): + target_ai_core_ops.append(ai_core_op_event) + target_ai_core_ops = sorted(target_ai_core_ops, key=lambda x: float(x.ts)) + return target_ai_core_ops + + def _get_op_frequency(self, ai_core_ops): + ai_core_freq = sorted(self._ai_core_freq, key=lambda x: float(x.ts)) + + op_index, freq_index = 0, 0 + while op_index < len(ai_core_ops) and freq_index < len(ai_core_freq): + op_event = ai_core_ops[op_index] + op_end_time = convert_to_float(op_event.ts) + convert_to_float(op_event.dur) + op_freq_list = [] + while freq_index < len(ai_core_freq): + freq_event = ai_core_freq[freq_index] + if convert_to_float(freq_event.end) < op_end_time: + op_freq_list.append(convert_to_float(freq_event.args.MHz)) + freq_index += 1 + continue + elif convert_to_float(freq_event.ts) < op_end_time: + if op_event.name not in self.op_freq: + self.op_freq[op_event.name] = {"count": 0, "dur": 0, "freq_list": []} + self.op_freq[op_event.name]["count"] += 1 + self.op_freq[op_event.name]["dur"] += convert_to_float(op_event.dur) + op_freq_list.append(convert_to_float(freq_event.args.MHz)) + self.op_freq[op_event.name]["freq_list"].append(min(op_freq_list)) + break + else: + break + + op_index += 1 diff --git a/profiler/advisor/dataset/cluster/cluster_dataset.py b/profiler/advisor/dataset/cluster/cluster_dataset.py index 09fda2d4dcf2df2f05abb0007befb5c5c36ef824..66bf993a2f1f8f2798857a2389dd3468239e6a00 100644 --- a/profiler/advisor/dataset/cluster/cluster_dataset.py +++ b/profiler/advisor/dataset/cluster/cluster_dataset.py @@ -1,15 +1,31 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import os +import re +from collections import defaultdict from profiler.advisor.dataset.dataset import Dataset from profiler.advisor.utils.utils import singleton from profiler.cluster_analyse.common_func.file_manager import FileManager from profiler.advisor.common import constant as const from profiler.cluster_analyse.common_func.constant import Constant -from collections import defaultdict from profiler.cluster_analyse.cluster_analysis import Interface from profiler.advisor.dataset.cluster.cluster_step_trace_time_bean import ClusterStepTraceTimeBean +from profiler.advisor.dataset.cluster.hccl_collection import HcclInfo logger = logging.getLogger() @@ -23,11 +39,11 @@ class ClusterDataset(Dataset): """ check whether input path is valid """ - for file in os.listdir(self.collection_path): - if file == 'cluster_analysis_output': - print("[INFO]Cluster has been analyzed " - "because of the existence of cluster analysis output directory.") - print("[INFO]Skip Cluster analyze backend.") + for filename in os.listdir(self.collection_path): + if filename == 'cluster_analysis_output': + logger.info("Cluster has been analyzed " + "because of the existence of cluster analysis output directory.") + logger.info("Skip Cluster analyze backend.") return True return False @@ -38,18 +54,18 @@ class ClusterDataset(Dataset): Constant.COLLECTION_PATH: self.collection_path, Constant.ANALYSIS_MODE: "all" } - print("[INFO] cluster analysis is in the process, please wait...") + logger.info("cluster analysis is in the process, please wait...") try: Interface(parameter).run() except Exception as e: raise ValueError(f"Cluster analyze backend failed:{e}") from e - def load_csv_data(self, file_name, dataBean): + def load_csv_data(self, file_name, data_bean): csv_path = os.path.join(self.collection_path, const.CLUSTER_ANALYSIS_OUTPUT, file_name) if not os.path.exists(csv_path): msg = "[ERROR] cluster_step_trace_time.csv doesn't exist, terminate analysis." raise RuntimeError(msg) - data = FileManager.read_csv_file(csv_path, dataBean) + data = FileManager.read_csv_file(csv_path, data_bean) return data def load_json_data(self, file_name): @@ -62,39 +78,58 @@ class ClusterDataset(Dataset): @singleton -class ClusterStepTraceTimeDataSet(ClusterDataset): +class ClusterStepTraceTimeDataset(ClusterDataset): RANK = "rank" + STAGE = "stage" def __init__(self, collection_path: str, data: dict, **kwargs): self._step_dict = defaultdict() + self._stages = [] super().__init__(collection_path, data) + def format_data(self, step_data: list): + step_dict = defaultdict(lambda: [0, 0, 0]) + for step_bean in step_data: + if step_bean.type == self.RANK: + step_rank_record = [] + step = str(step_bean.step).replace(" ", "") or str(const.DEFAULT_STEP) + rank = str(step_bean.index).replace(" ", "") + if step: + step_rank_record.append(step) + if rank: + step_rank_record.append(rank) + + step_rank_index = const.STEP_RANK_SEP.join(step_rank_record) + step_dict[step_rank_index][0] += step_bean.compute + step_dict[step_rank_index][1] += step_bean.communication + step_dict[step_rank_index][2] += step_bean.free + if step_bean.type == self.STAGE: + stage = sorted(list(map(int, re.findall(r'\d+', step_bean.stage)))) + if stage in self._stages: + continue + self._stages.append(stage) + return step_dict + + def get_data(self): + return self._step_dict + + def get_stages(self): + return sorted(self._stages) + def _parse(self): self.cluster_analyze() try: step_data = self.load_csv_data(const.CLUSTER_STEP_TIME_CSV, ClusterStepTraceTimeBean) except RuntimeError as e: - print("捕获到异常:", e) + logger.error("捕获到异常:%s", e) self._step_dict = None return False - self._step_dict = self.formate_data(step_data) + self._step_dict = self.format_data(step_data) return True - def formate_data(self, step_data: list): - step_dict = defaultdict(lambda: [0, 0, 0]) - for step_bean in step_data: - if step_bean.type == self.RANK: - step_dict[step_bean.index][0] += step_bean.compute - step_dict[step_bean.index][1] += step_bean.communication - step_dict[step_bean.index][2] += step_bean.free - return step_dict - - def get_data(self): - return self._step_dict - @singleton -class ClusterCommunicationDataSet(ClusterDataset): +class ClusterCommunicationDataset(ClusterDataset): RDMA_TIME_MS = "RDMA time(ms)" RDMA_SIZE_MB = "RDMA size(mb)" SDMA_TIME_MS = "SDMA time(ms)" @@ -108,12 +143,8 @@ class ClusterCommunicationDataSet(ClusterDataset): RDMA = "RDMA" def __init__(self, collection_path: str, data: dict, **kwargs): - self.rank_bw_dict = defaultdict(lambda: { - self.RDMA_TIME_MS: 0, - self.RDMA_SIZE_MB: 0, - self.SDMA_TIME_MS: 0, - self.SDMA_SIZE_MB: 0, - }) + self.rank_bw_dict = defaultdict(self.create_rank_bw_dict) + self.hccl_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) super().__init__(collection_path, data) @staticmethod @@ -122,25 +153,39 @@ class ClusterCommunicationDataSet(ClusterDataset): return 0 else: return round(dividend / divisor, 4) - - def _parse(self): - self.cluster_analyze() - try: - communication_json = self.load_json_data(const.CLUSTER_COMM_JSON) - except RuntimeError as e: - print("捕获到异常:", e) - self.rank_bw_dict = None - return False - self.process(communication_json) - return True + + def create_rank_bw_dict(self): + return{ + self.RDMA_TIME_MS: 0, + self.RDMA_SIZE_MB: 0, + self.SDMA_TIME_MS: 0, + self.SDMA_SIZE_MB: 0, + } def process(self, communication_json: dict): for comm_group, group_dict in communication_json.items(): + if self.hccl_dict.get(comm_group) is None: + self.hccl_dict.setdefault(comm_group, defaultdict(lambda: defaultdict(list))) for step, step_dict in group_dict.items(): for op, op_dict in step_dict.items(): - self.compute_bandwidth(op_dict) + self.compute_bandwidth(step.lower().lstrip("step") or str(const.DEFAULT_STEP), op_dict) + self.process_hccl_info(comm_group, step, op, op_dict) - def compute_bandwidth(self, op_dict: dict): + def process_hccl_info(self, group, step, op, op_dict): + op_name = op.split("@")[0] + for rank_id, rank_dict in op_dict.items(): + try: + hccl_info = HcclInfo(group, step, rank_id, op, rank_dict) + if self.hccl_dict[group].get(op_name) is None: + self.hccl_dict[group].setdefault(op_name, defaultdict(list)) + if self.hccl_dict[group][op_name].get(step) is None: + self.hccl_dict[group][op_name].setdefault(step, list()) + self.hccl_dict[group][op_name][step].append(hccl_info) + except ValueError as e: + msg = "[ERROR] Cluster_communication.json has invalid structure." + raise ValueError(msg) from e + + def compute_bandwidth(self, step, op_dict: dict): for rank_id, rank_dict in op_dict.items(): try: rank = int(rank_id) @@ -149,17 +194,32 @@ class ClusterCommunicationDataSet(ClusterDataset): raise ValueError(msg) from e for comm_type, bw_dict in rank_dict.get(self.COMMUNICATION_BANDWIDTH_INFO, {}).items(): if comm_type == self.SDMA: - self.rank_bw_dict[rank][self.SDMA_SIZE_MB] += bw_dict.get(self.TRANSIT_SIZE) - self.rank_bw_dict[rank][self.SDMA_TIME_MS] += bw_dict.get(self.TRANSIT_TIME) + self.rank_bw_dict[f"{step}{const.STEP_RANK_SEP}{rank}"][self.SDMA_SIZE_MB] += \ + bw_dict.get(self.TRANSIT_SIZE) + self.rank_bw_dict[f"{step}{const.STEP_RANK_SEP}{rank}"][self.SDMA_TIME_MS] += \ + bw_dict.get(self.TRANSIT_TIME) if comm_type == self.RDMA: - self.rank_bw_dict[rank][self.RDMA_SIZE_MB] += bw_dict.get(self.TRANSIT_SIZE) - self.rank_bw_dict[rank][self.RDMA_TIME_MS] += bw_dict.get(self.TRANSIT_TIME) + self.rank_bw_dict[f"{step}{const.STEP_RANK_SEP}{rank}"][self.RDMA_SIZE_MB] += \ + bw_dict.get(self.TRANSIT_SIZE) + self.rank_bw_dict[f"{step}{const.STEP_RANK_SEP}{rank}"][self.RDMA_TIME_MS] += \ + bw_dict.get(self.TRANSIT_TIME) - for rank, rank_dict in self.rank_bw_dict.items(): - self.rank_bw_dict[rank][self.RDMA_BANDWIDTH] = self.compute_ratio( - self.rank_bw_dict[rank][self.RDMA_SIZE_MB], self.rank_bw_dict[rank][self.RDMA_TIME_MS]) - self.rank_bw_dict[rank][self.SDMA_BANDWIDTH] = self.compute_ratio( - self.rank_bw_dict[rank][self.SDMA_SIZE_MB], self.rank_bw_dict[rank][self.SDMA_TIME_MS]) + for step_rank in self.rank_bw_dict.keys(): + self.rank_bw_dict[step_rank][self.RDMA_BANDWIDTH] = self.compute_ratio( + self.rank_bw_dict[step_rank][self.RDMA_SIZE_MB], self.rank_bw_dict[step_rank][self.RDMA_TIME_MS]) + self.rank_bw_dict[step_rank][self.SDMA_BANDWIDTH] = self.compute_ratio( + self.rank_bw_dict[step_rank][self.SDMA_SIZE_MB], self.rank_bw_dict[step_rank][self.SDMA_TIME_MS]) def get_data(self): return self.rank_bw_dict + + def _parse(self): + self.cluster_analyze() + try: + communication_json = self.load_json_data(const.CLUSTER_COMM_JSON) + except RuntimeError as e: + logger.error("捕获到异常:%s", e) + self.rank_bw_dict = None + return False + self.process(communication_json) + return True diff --git a/profiler/advisor/dataset/cluster/cluster_step_trace_time_bean.py b/profiler/advisor/dataset/cluster/cluster_step_trace_time_bean.py index b108fc77a3f3408d48c79ce6b542f98427d88b0b..8ae0e55f2a5fbc05304fd95809e9b69220dfd3e5 100644 --- a/profiler/advisor/dataset/cluster/cluster_step_trace_time_bean.py +++ b/profiler/advisor/dataset/cluster/cluster_step_trace_time_bean.py @@ -65,3 +65,6 @@ class ClusterStepTraceTimeBean: msg = "[ERROR] Cluster step trace time.csv has invalid value in column 'Free'." raise ValueError(msg) from e + @property + def stage(self) -> int: + return self._data.get(self.INDEX) diff --git a/profiler/advisor/dataset/cluster/hccl_collection.py b/profiler/advisor/dataset/cluster/hccl_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..6a7156e496db05da19f30d7c049794046d2ceebb --- /dev/null +++ b/profiler/advisor/dataset/cluster/hccl_collection.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +hccl info +""" +import logging + +logger = logging.getLogger() + + +class HcclInfo(): + def __init__(self, group: str, step: str, rank: str, op: str, rank_dict: dict) -> None: + self._group = group + self._step = step + self._rank = rank + self._name = op.split("@")[0] + self._ts = self.get_communication_time_info(rank_dict, "Start Timestamp(us)") + self._elapse_time = self.get_communication_time_info(rank_dict, "Elapse Time(ms)") + self._sdma_info = self.get_communication_info(rank_dict, "SDMA") + self._rdma_info = self.get_communication_info(rank_dict, "RDMA") + + @property + def group(self): + return self._group + + @property + def step(self): + return self._step + + @property + def rank(self): + return self._rank + + @property + def name(self): + return self._name + + @property + def rdma_info(self): + return self._rdma_info + + @property + def sdma_info(self): + return self._sdma_info + + @property + def elapse_time(self): + return self._elapse_time + + @property + def ts(self): + return self._ts + + @staticmethod + def get_communication_info(rank_dict: dict, name: str): + communication_bandwidth_info = rank_dict.get('Communication Bandwidth Info', dict()) + return communication_bandwidth_info.get(name, dict()) + + @staticmethod + def get_communication_time_info(rank_dict: dict, name: str): + communication_time_info = rank_dict.get('Communication Time Info', dict()) + return communication_time_info.get(name, 0) + + def get_rdma_transmit_time(self): + return self.rdma_info.get('Transit Time(ms)', 0) + + def get_rdma_transit_size(self): + return self.rdma_info.get('Transit Size(MB)', 0) + + def get_rdma_bandwidth(self): + return self.rdma_info.get('Bandwidth(GB/s)', 0) diff --git a/profiler/advisor/dataset/communication/__init__.py b/profiler/advisor/dataset/communication/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/dataset/communication/communication_dataset.py b/profiler/advisor/dataset/communication/communication_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..01a72ef93044ad8f0afb5af4ee864a99c7865060 --- /dev/null +++ b/profiler/advisor/dataset/communication/communication_dataset.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from collections import defaultdict +from profiler.advisor.utils.utils import singleton +from profiler.advisor.common import constant as const +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.dataset.cluster.hccl_collection import HcclInfo +from profiler.advisor.utils.utils import CheckPathAccess + +logger = logging.getLogger() + + +@singleton +class CommunicationDataset: + RANK = "rank" + + def __init__(self, collection_path, data: dict, **kwargs) -> None: + self.timeline_dir = collection_path + if not self.timeline_dir.endswith("ascend_pt"): + return + self.timeline_data_list = self.get_file_path_from_directory( + self.timeline_dir, + lambda file: file.endswith(const.COMMUNICATION_JSON) + ) + self.hccl_dict = defaultdict(list) + self.step = kwargs.get("step") + if self.parse(): + key = self.get_key() + if key not in data: + data[key] = [] + data[key].append(self) + + @staticmethod + def load_json_data(json_path): + if not os.path.exists(json_path): + msg = "[ERROR] cluster_communication.json doesn't exist, terminate analysis." + raise RuntimeError(msg) + data = FileManager.read_json_file(json_path) + return data + + @staticmethod + @CheckPathAccess + def get_file_path_from_directory(path, check_func): + """ + get file from directory + """ + file_list = [] + + if not path: + return file_list + + if not os.path.isdir(path): + logger.warning("Expected existed directory, but got %s", path) + + for root, _, files in os.walk(path): + if root.endswith("cluster_analysis_output"): + continue + for filename in files: + filepath = os.path.join(root, filename) + if check_func(filename): + file_list.append(filepath) + return file_list + + @classmethod + def get_key(cls): + """ + get key of dataset + :return: key + """ + return cls.__module__.rsplit('.', maxsplit=1)[-1] + + def parse(self): + if len(self.timeline_data_list) == 0: + logger.warning("Please ensure communication.json in %s, skip timeline analysis.", self.timeline_dir) + return False + + if len(self.timeline_data_list) > 1: + logger.warning("Found multiple communication.json in %s, load the file of device 0 for analysis.", + self.timeline_dir) + + json_data = self.load_json_data(sorted(self.timeline_data_list)[0]) + self.process(json_data) + return True + + def process(self, communication_json: dict): + for step, step_dict in communication_json.items(): + for group, group_dict in step_dict.items(): + for op, op_dict in group_dict.items(): + self.process_hccl_info(group, step, op, op_dict) + + def process_hccl_info(self, group, step, op, op_dict): + try: + hccl_info = HcclInfo(group, step, "None", op, op_dict) + if self.hccl_dict.get(step) is None: + self.hccl_dict.setdefault(step, list()) + self.hccl_dict[step].append(hccl_info) + except ValueError as e: + msg = "[ERROR] Cluster_communication.json has invalid structure." + raise ValueError(msg) from e diff --git a/profiler/advisor/dataset/dataset.py b/profiler/advisor/dataset/dataset.py index 7f1e40a38b8a4a26585eecfe6271cc75ea054d2d..becd3e6e88d89326b2dcdacdd58add2ee150c17b 100644 --- a/profiler/advisor/dataset/dataset.py +++ b/profiler/advisor/dataset/dataset.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ dataset module """ @@ -26,7 +41,8 @@ class Dataset: data[key] = [] data[key].append(self) - def _parse(self): + @staticmethod + def _parse(): return None @classmethod diff --git a/profiler/advisor/dataset/environment_variable_dataset.py b/profiler/advisor/dataset/environment_variable_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..577273ffe8ae955ae8b33e1d871ef2f867aa3f71 --- /dev/null +++ b/profiler/advisor/dataset/environment_variable_dataset.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import logging + +from profiler.advisor.common import constant +from profiler.cluster_analyse.common_func.file_manager import FileManager + + +class EnvironmentVariableDataset: + def __init__(self, collection_path, data: dict, **kwargs): + self.collection_path = collection_path + self.env_data = {} + self.read_data() + + @staticmethod + def get_env_data_file(collection_path: str) -> str: + for root, _, files in os.walk(collection_path): + for file_name in files: + if file_name == constant.PROFILER_METADATA: + return os.path.join(root, file_name) + return "" + + @classmethod + def get_key(cls): + return cls.__module__.rsplit('.', maxsplit=1)[-1] + + def read_data(self): + data_path = self.get_env_data_file(self.collection_path) + if not data_path: + return + try: + self.env_data = FileManager.read_json_file(data_path) + except RuntimeError as e: + logging.error("Read json failed. %s", str(e)) diff --git a/profiler/advisor/dataset/graph_dataset.py b/profiler/advisor/dataset/graph_dataset.py index 951de7fd26b1f986d25285547e63b1a420968249..d02af46b3a71413d3871b765fb878d7b46a958d5 100644 --- a/profiler/advisor/dataset/graph_dataset.py +++ b/profiler/advisor/dataset/graph_dataset.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging from typing import List @@ -19,17 +34,6 @@ class GraphDataset(Dataset): self.graph_files: List[HostGraphParser] = [] super().__init__(collection_path, data) - def _parse(self): - graph_list = get_file_path_from_directory(self.collection_path, - lambda file: file.endswith( - load_parameter(self.FILE_PATTERN, "_Build.txt"))) - - for graph_file_path in graph_list[-1:]: - logger.info("Prepare to parse %s as default graph.", graph_file_path) - graph_file = HostGraphParser(graph_file_path) - self.graph_files.append(graph_file) - return self.graph_files - @lazy_property def graphs(self) -> List[Graph]: """ @@ -51,3 +55,13 @@ class GraphDataset(Dataset): def is_empty(self) -> bool: """check empty graph dataset""" return len(self.graph_files) == 0 + + def _parse(self): + def is_matching_file(file): + return file.endswith(load_parameter(self.FILE_PATTERN, "Build.txt")) + graph_list = get_file_path_from_directory(self.collection_path, is_matching_file) + for graph_file_path in graph_list[-1:]: + logger.info("Prepare to parse %s as default graph.", graph_file_path) + graph_file = HostGraphParser(graph_file_path) + self.graph_files.append(graph_file) + return self.graph_files diff --git a/profiler/advisor/dataset/profiling/builder_base.py b/profiler/advisor/dataset/profiling/builder_base.py index 2bfe14f9462b701db2a4ede1d539a07659f48ae8..77bd926f72c5942e0c234e787cc59b3f6e572319 100644 --- a/profiler/advisor/dataset/profiling/builder_base.py +++ b/profiler/advisor/dataset/profiling/builder_base.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ profiling base """ diff --git a/profiler/advisor/dataset/profiling/db_manager.py b/profiler/advisor/dataset/profiling/db_manager.py index c9fb73c7cf69d94c3ca1aba8c726f574d63cd1a3..6e1dfbf58c4376af5e1745f4aa86ea9e786d3621 100644 --- a/profiler/advisor/dataset/profiling/db_manager.py +++ b/profiler/advisor/dataset/profiling/db_manager.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ connection manager """ @@ -34,6 +49,21 @@ class ConnectionManager: return False return True + @classmethod + def get_connection(cls, path, dbs, tables=None, is_host=False): + """ + get connection + """ + if is_host: + pattern = r"/device_[0-9]" + path = re.sub(pattern, "/host", path) + if not cls.check_db_exists(path, dbs): + return None + conn = cls(path, dbs) + if tables and not conn.check_table_exists(tables): + return None + return conn + def check_table_exists(self, tables:List) -> bool: """ check table exists @@ -53,18 +83,3 @@ class ConnectionManager: if column not in self.metadata.tables[table_name].columns: return False return True - - @classmethod - def get_connection(cls, path, dbs, tables=None, is_host=False): - """ - get connection - """ - if is_host: - pattern = r"/device_[0-9]" - path = re.sub(pattern, "/host", path) - if not cls.check_db_exists(path, dbs): - return None - conn = cls(path, dbs) - if tables and not conn.check_table_exists(tables): - return None - return conn diff --git a/profiler/advisor/dataset/profiling/device_info.py b/profiler/advisor/dataset/profiling/device_info.py index b58930777f969d023eab7885a9095d46aa7ba6ea..abb0e6000c4b0a5b10517a5789f3bbb6c47a6aa6 100644 --- a/profiler/advisor/dataset/profiling/device_info.py +++ b/profiler/advisor/dataset/profiling/device_info.py @@ -6,6 +6,7 @@ import logging from profiler.advisor.config.config import Config from profiler.advisor.utils.utils import get_file_path_from_directory +from profiler.cluster_analyse.common_func.file_manager import FileManager logger = logging.getLogger() @@ -40,9 +41,8 @@ class DeviceInfoParser: if info_file.endswith("done"): return False # skip info.json.0.done try: - with open(info_file, encoding="utf-8") as file: - info = json.load(file) - except (IOError, ValueError) as error: + info = FileManager.read_json_file(info_file) + except RuntimeError as error: logger.error("Parse json info file %s failed : %s", info_file, error) return False if "DeviceInfo" not in info: @@ -54,6 +54,8 @@ class DeviceInfoParser: config.set_config("device_id", device_info["id"]) if "aiv_num" in device_info: config.set_config("aiv_num", device_info["aiv_num"]) + if "aic_frequency" in device_info: + config.set_config("aic_frequency", device_info["aic_frequency"]) if "ai_core_num" in device_info: config.set_config("ai_core_num", device_info["ai_core_num"]) return True diff --git a/profiler/advisor/dataset/profiling/info_collection.py b/profiler/advisor/dataset/profiling/info_collection.py index b1f84313bb7980ea2186d2727db51b5fba49e12e..a3810dd0fcb2feb59c769c6119cbed1335da6793 100644 --- a/profiler/advisor/dataset/profiling/info_collection.py +++ b/profiler/advisor/dataset/profiling/info_collection.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ profiling info """ @@ -14,6 +29,7 @@ class Info: op info """ _attr_pre_fix_list = [""] + FFTS_TYPE = "ffts_type" def add_attr(self, key: str, value: str): """ @@ -116,7 +132,7 @@ class OpInfo(Info): if hasattr(self, attr): try: if float(getattr(self, attr)) > 0: - if hasattr(self, "ffts_type") and getattr(self, "ffts_type") == "1": + if hasattr(self, self.FFTS_TYPE) and getattr(self, self.FFTS_TYPE) == "1": logger.warning( "ffts type of op %s is vector buf mac ratio is not 0", getattr(self, "op_name") ) @@ -124,7 +140,7 @@ class OpInfo(Info): except ValueError: pass # not cube op - if hasattr(self, "ffts_type") and getattr(self, "ffts_type") == "0": + if hasattr(self, self.FFTS_TYPE) and getattr(self, self.FFTS_TYPE) == "0": logger.warning("ffts type of op %s is cube but mac ratio is 0", getattr(self, "op_name")) return False diff --git a/profiler/advisor/dataset/profiling/profiling_dataset.py b/profiler/advisor/dataset/profiling/profiling_dataset.py index 46d4a4fe8b12a419f6d0d7472f9776369e122f03..0db673e18871266316d2f5c3673021aee801d8ec 100644 --- a/profiler/advisor/dataset/profiling/profiling_dataset.py +++ b/profiler/advisor/dataset/profiling/profiling_dataset.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import os @@ -7,49 +22,54 @@ from profiler.advisor.common.profiling.ge_info import GeInfo from profiler.advisor.common.profiling.msprof import Msprof from profiler.advisor.common.profiling.op_summary import OpSummary from profiler.advisor.common.profiling.tasktime import TaskTime +from profiler.advisor.common.enum_params_parser import EnumParamsParser from profiler.advisor.dataset.dataset import Dataset from profiler.advisor.dataset.profiling.device_info import DeviceInfoParser from profiler.advisor.utils.utils import join_prof_path +from profiler.cluster_analyse.common_func.file_manager import FileManager logger = logging.getLogger() class ProfilingDataset(Dataset): - PROF_TYPE = "" + prof_type = "" def __init__(self, collection_path, data: dict, **kwargs) -> None: - self.cann_version = kwargs.get("cann_version", constant.DEFAULT_CANN_VERSION) - self.PROF_TYPE = kwargs.get("profiling_type", constant.DEFAULT_PROFILING_TYPE) + self.cann_version = kwargs.get(constant.CANN_VERSION, EnumParamsParser().get_default(constant.CANN_VERSION)) + self.prof_type = kwargs.get(constant.PROFILING_TYPE, EnumParamsParser().get_default(constant.PROFILING_TYPE)) self.patterns = self.parse_pattern() self.current_version_pattern = self.get_current_version_pattern() + self._info = None super().__init__(collection_path, data) - def _parse(self): - info = DeviceInfoParser(self.collection_path) - if info.parse_data(): - self._info = info - ret = False - if self.current_version_pattern is not None: - self.build_from_pattern(self.current_version_pattern["dirs_pattern"], self.collection_path) - ret = True - - return ret - - def build_from_pattern(self, dirs_pattern, current_path): + def build_from_pattern(self, dirs_pattern, current_path, depth): + if depth > constant.DEPTH_LIMIT: + logger.error("Recursion depth exceeds limit!") + return + depth += 1 if isinstance(dirs_pattern, dict): for key, value in dirs_pattern.items(): - self.build_from_pattern(value, join_prof_path(current_path, key)) + self.build_from_pattern(value, join_prof_path(current_path, key), depth) elif isinstance(dirs_pattern, list): for item in dirs_pattern: + if hasattr(self, item) and getattr(self, item): + # 避免重复构建kernel_details.csv, op_summary.csv的数据对象 + continue + file_pattern_list = self.current_version_pattern.get('file_attr').get(item) data_class = globals()[self.current_version_pattern.get('class_attr').get(item)] - data_class.FILE_PATTERN = self.current_version_pattern.get('file_attr').get(item) + if not hasattr(data_class, "file_pattern_list"): + continue + setattr(data_class, "file_pattern_list", self.current_version_pattern.get('file_attr').get(item)) data_object = data_class(current_path) is_success = data_object.parse_data() if is_success: setattr(self, item, data_object) else: - logger.warning("Skip parse %s from local path %s", self.current_version_pattern.get('class_attr').get(item), current_path) + logger.info("Skip parse %s with file pattern %s from local path %s", + self.current_version_pattern.get('class_attr').get(item), + file_pattern_list, current_path + ) else: logger.warning(f"Unsupported arguments : %s to build %s", dirs_pattern, self.__class__.__name__) @@ -69,11 +89,22 @@ class ProfilingDataset(Dataset): logger.warning("Skip parse profiling dataset, because %s does not exist.", config_path) return [] - with open(config_path, 'r') as f: - patterns = yaml.safe_load(f) + patterns = FileManager.read_yaml_file(config_path) - return patterns + return patterns if patterns else [] def collection_path(self): """collection_path""" return self.collection_path + + def _parse(self): + info = DeviceInfoParser(self.collection_path) + if info.parse_data(): + self._info = info + ret = False + dirs_pattern = self.current_version_pattern.get("dirs_pattern") + if dirs_pattern is not None: + self.build_from_pattern(dirs_pattern, self.collection_path, 0) + ret = True + + return ret diff --git a/profiler/advisor/dataset/profiling/profiling_parser.py b/profiler/advisor/dataset/profiling/profiling_parser.py index bb4caeb29e5c94cbc4373b1d6b10e32f3e10e02e..9f0f476de040ec78c3de39c82ff449b998525f41 100644 --- a/profiler/advisor/dataset/profiling/profiling_parser.py +++ b/profiler/advisor/dataset/profiling/profiling_parser.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import csv import json import os @@ -6,47 +21,35 @@ from typing import List, Dict from profiler.advisor.dataset.profiling.info_collection import logger from profiler.advisor.utils.utils import get_file_path_from_directory, SafeOpen, format_excel_title +from profiler.cluster_analyse.common_func.file_manager import FileManager class ProfilingParser: """ profiling """ - FILE_PATTERN = "" FILE_PATTERN_MSG = "" FILE_INFO = "" - FILE_PATH = "" + + file_pattern_list = [] def __init__(self, path: str) -> None: self._path = path - self._raw_data: List[List[str]] = [] + self._raw_data: Dict = dict() self._filename = "" + @property + def path(self): + """ + path + """ + return self._path + @staticmethod def file_match_func(pattern): """file match function""" return lambda x: re.search(re.compile(pattern), x) - def parse_data(self) -> bool: - """ - pase task time file - :return: true or false - """ - if self._parse_from_file(): - return True - return False - - def _parse_from_file(self): - file_list = get_file_path_from_directory(self._path, self.file_match_func(self.FILE_PATTERN)) - if not file_list: - return False - ## get last file - file = file_list[-1] - self.FILE_PATH = file - if len(file_list) > 1: - logger.warning("Multiple copies of %s were found, use %s", self.FILE_INFO, file) - return self.parse_from_file(file) - @staticmethod def get_float(data) -> float: """ @@ -57,12 +60,6 @@ class ProfilingParser: except (FloatingPointError, ValueError): return 0.0 - def parse_from_file(self, file): - """ - parse from file - """ - return False - @staticmethod def _check_csv_file_format(csv_file_name: str, csv_content: List[List[str]]): if not csv_content: @@ -70,8 +67,66 @@ class ProfilingParser: return False return True + @staticmethod + def _get_csv_title(data: List, number=0, title_index=0): + """ + number = 0 replace (us) (ns).. + other replace " " to "_" + title_index: position of title default 0 + """ + title_dict: Dict[int, str] = {} + for idx, title in enumerate(data[title_index]): + if number == 0: + title_dict[idx] = format_excel_title(title) + else: + title_dict[idx] = title.replace(" ", "_") + return title_dict + + @staticmethod + def parse_from_file(file): + """ + parse from file as a static method + """ + # 实现解析文件的逻辑,这里可以根据需要进行扩展 + return False + + def parse_data(self) -> bool: + """ + Parse task time file + :return: true or false + """ + if self._parse_from_file(): + return True + return False + + def get_raw_data(self): + """ + get raw file name and data + """ + return self._filename, self._raw_data + + def _parse_from_file(self): + if not isinstance(self.file_pattern_list, list): + self.file_pattern_list = [self.file_pattern_list] + + for file_pattern in self.file_pattern_list: + file_list = get_file_path_from_directory(self._path, self.file_match_func(file_pattern)) + if not file_list: + continue + # get last file + target_file = file_list[-1] + if len(file_list) > 1: + logger.warning("Multiple copies of %s were found, use %s", self.FILE_INFO, target_file) + return self.parse_from_file(target_file) + return False + def _parse_csv(self, file, check_csv=True) -> bool: logger.debug("Parse file %s", file) + try: + FileManager.check_file_size(file) + except RuntimeError as e: + logger.error("File size check failed: %s", e) + return False self._filename = os.path.splitext(os.path.basename(file))[0] with SafeOpen(file, encoding="utf-8") as csv_file: try: @@ -96,37 +151,8 @@ class ProfilingParser: logger.debug("Parse file %s", file) self._filename = os.path.splitext(os.path.basename(file))[0] try: - with open(file, encoding="utf-8") as json_file: - self._raw_data = json.load(json_file) - except (OSError, ValueError) as error: + self._raw_data = FileManager.read_json_file(file) + except RuntimeError as error: logger.error("Parse json file %s failed : %s", file, error) return False - return True - - def get_raw_data(self): - """ - get raw file name and data - """ - return self._filename, self._raw_data - - @staticmethod - def _get_csv_title(data: List, number=0, title_index=0): - """ - number = 0 replace (us) (ns).. - other replace " " to "_" - title_index: position of title default 0 - """ - title_dict: Dict[int, str] = {} - for idx, title in enumerate(data[title_index]): - if number == 0: - title_dict[idx] = format_excel_title(title) - else: - title_dict[idx] = title.replace(" ", "_") - return title_dict - - @property - def path(self): - """ - path - """ - return self._path + return True \ No newline at end of file diff --git a/profiler/advisor/dataset/timeline_event_dataset.py b/profiler/advisor/dataset/timeline_event_dataset.py index 94b6fdfef78c044e37e24772699ed7ea67b0da30..1ee2573c4c1077d12fc7546dff47ba34ea10b6df 100644 --- a/profiler/advisor/dataset/timeline_event_dataset.py +++ b/profiler/advisor/dataset/timeline_event_dataset.py @@ -1,220 +1,224 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect import logging -from typing import List +import traceback +from collections import OrderedDict import ijson -from profiler.advisor.dataset.dataset import Dataset from tqdm import tqdm from profiler.advisor.common import constant as const from profiler.advisor.common.timeline.event import TimelineEvent -from profiler.advisor.utils.utils import get_file_path_from_directory -from profiler.advisor.utils.utils import singleton +from profiler.advisor.utils.utils import get_file_path_from_directory, check_path_valid, singleton, convert_to_float +from profiler.advisor.dataset.timeline_op_collector.timeline_op_collector import ( + OpCompileCollector, + SynchronizeStreamCollector, + MemCollector, + DataloaderCollector, + SyncBNCollector, + AtenCollector, + OptimizerCollector, + FrequencyCollector, + SpecificTaskTypeOpCollector, + TorchToNpuCollector, + AclToNpuCollector, + OpStackCollector, + StepCollector, + GcCollector, + FreeEventsCollector, + AclEventsCollector +) logger = logging.getLogger() -class OpCompileCollector: - def __init__(self): - self._total_op_compile_counter = 0 - self._total_op_compile_time = 0.0 +class BaseTimelineEventDataset: + PROFILER_STEP_PREFIX = "ProfilerStep" - @property - def total_time(self): - return self._total_op_compile_time + collector_map = {} - @property - def total_count(self): - return self._total_op_compile_counter - - def is_empty(self): - return self._total_op_compile_counter == 0 - - def update(self, event: TimelineEvent): - self._total_op_compile_time += float(event.dur) - self._total_op_compile_counter += 1 - - def unset(self): - self._total_op_compile_counter = 0 - self._total_op_compile_time = 0.0 - - -@singleton -class TimelineEventDataset(Dataset): - - def __init__(self, collection_path, data: dict, **kwargs) -> None: - self._ops_with_task_type = {} - self._ops_with_stack = {} - self._ops_compile = OpCompileCollector() - self._torch_to_npu = {} - self._acl_to_npu = set() - self._aten: List[str] = [] - self._optimizer: List[str] = [] + def __init__(self, collection_path, data: dict, build_dataset=True, **kwargs) -> None: self.timeline_dir = collection_path - self.timeline_data_list = get_file_path_from_directory(collection_path, lambda file: file.endswith("trace_view.json")) + self.profiler_step = [] + self.timeline_data_list = get_file_path_from_directory(collection_path, + lambda file: file.endswith("trace_view.json")) self.dataset_len = None - self.analysis_mode = kwargs.get("analysis_mode") - self.task_type = kwargs.get("task_type") - self.cann_version = kwargs.get("cann_version") - self.torch_version = kwargs.get("torch_version") - - if self.analysis_mode in ["fusion_ops", "all"]: - logger.info("Load fusion operators database for cann version '%s' and torch version '%s'", - self.cann_version, self.torch_version) - - super().__init__(collection_path, data) - - if self.analysis_mode in ["op_stack", "all"]: - self._task_op_names = list(set([event_key.split("-")[0] for event_key in self._ops_with_task_type.keys()])) - - self._post_process() - - - @property - def ops_with_stack(self): - return self._ops_with_stack - - @property - def ops_compile(self): - return self._ops_compile - - @property - def torch_to_npu(self): - return self._torch_to_npu - - @property - def acl_to_npu(self): - return self._acl_to_npu - - @property - def ops_with_task_type(self): - return self._ops_with_task_type - - @property - def task_op_names(self): - return self._task_op_names + self.step = kwargs.get("step") + self.step_duration = 0.0 + if not build_dataset: + return - @property - def optimizer(self): - return self._optimizer + if self.parse(): + key = self.get_key() + if key not in data: + data[key] = [] + data[key].append(self) + + @classmethod + def get_key(cls): + """ + get key of dataset + :return: key + """ + return cls.__module__.rsplit('.', maxsplit=1)[-1] + + def get_post_process_kwargs(self, func_name): + kwargs = {} + if func_name == FrequencyCollector.__name__: + ops_with_task_type = getattr(self, "ops_with_task_type", {}).values() + kwargs["ai_core_ops"] = [ + op for op in ops_with_task_type if + op.get(const.TASK_TYPE) in [const.AI_CORE, const.MIX_AIC] + ] + return kwargs + + def add_event(self, index, event): + event["dataset_index"] = index + if not isinstance(event, TimelineEvent): + event = TimelineEvent(event) - @property - def aten(self): - return self._aten + for _, collector in self.collector_map.items(): + collector.add_op(event) + return True - def _parse(self): + def parse(self): if len(self.timeline_data_list) == 0: logger.warning("Please ensure trace_view.json in %s, skip timeline analysis.", self.timeline_dir) return False if len(self.timeline_data_list) > 1: - logger.warning("Please ensure only one trace_view.json in %s, there will analyze first timeline profiling data.", self.timeline_dir) - self.timeline_data_list = [self.timeline_data_list[0]] + logger.warning("Found multiple trace_view.json in %s, load the file of device 0 for analysis .", + self.timeline_dir) - result = self.parse_data_with_generator(self._add_event) + result = self.parse_data_with_generator(self.add_event) if not self.dataset_len: self.dataset_len = len(result) - return True def parse_data_with_generator(self, func): result = [] + timeline_data_path = sorted(self.timeline_data_list)[0] + if not check_path_valid(timeline_data_path): + return result + try: - with open(self.timeline_data_list[0], "r") as f: + with open(timeline_data_path, "r") as f: for i, event in tqdm(enumerate(ijson.items(f, "item")), leave=False, ncols=100, desc="Building dataset for timeline analysis", total=self.dataset_len): func_res = func(index=i, event=event) if func_res is not None: result.append(func_res) - except Exception as e: - logger.warning("Error %s while parsing file %s, continue to timeline analysis", e, - self.timeline_data_list[0]) - return result - def _add_ops_with_task_type(self, event): - key = f"{event.name}-{event.ts}" - self._ops_with_task_type[key] = TimelineEvent( - { - const.TASK_TYPE: event.args.get(const.TASK_TYPE), - "task_id": event.args.get("Task Id"), - "tid": event.tid, - "name": event.name, - "ts": str(event.ts) - } - ) - - def _add_ops_with_stack(self, event): - self._ops_with_stack[str(event.ts)] = TimelineEvent({"name": event.name, "dataset_index": event.dataset_index}) - - def _add_torch_to_npu(self, event): - key = f"{event.ph}-{event.id}" - self._torch_to_npu[key] = TimelineEvent({"tid": event.tid, "ts": str(event.ts)}) - - def _add_acl_to_npu(self, event): - # op with task type equals to ai_cpu which derived from acl_to_npu do not have stacks - self._acl_to_npu.add(str(event.ts)) - - def _add_op_compile(self, event: TimelineEvent): - if event.name == const.OP_COMPILE_NAME or event.args.get("id") == const.OP_COMPILE_ID: - self._ops_compile.update(event) - - def _add_optimizer(self, event: TimelineEvent): - self._optimizer.append(TimelineEvent({"name": event.name, "dataset_index": event.dataset_index})) - - def _add_aten(self, event: TimelineEvent): - self._aten.append(TimelineEvent({ - "name": event.name, "dataset_index": event.dataset_index, "ts": event.ts, "dur": event.dur - })) - - def _add_event(self, index, event): - event["dataset_index"] = index - if not isinstance(event, TimelineEvent): - event = TimelineEvent(event) + except Exception: + logger.warning("Error %s while parsing file %s, continue to timeline analysis", traceback.format_exc(), + timeline_data_path) + return result - self._add_op_compile(event) - if self.analysis_mode == "fusion_ops": - self._add_event_for_fusion_ops(event) - elif self.analysis_mode == "op_stack": - self._add_event_for_op_stack(event) + def _get_target_ops_by_step(self, op_list): + target_ops = [] + if not self.profiler_step: + return op_list + if not self.step or f"ProfilerStep#{self.step}" not in [event.name for event in self.profiler_step]: + target_ops = op_list + if self.profiler_step: + self.step_duration = convert_to_float(self.profiler_step[-1].dur) else: - self._add_event_for_fusion_ops(event) - self._add_event_for_op_stack(event) - return True + for step_event in self.profiler_step: + if step_event.name != f"ProfilerStep#{self.step}": + continue + self.step_duration = convert_to_float(step_event.dur) + for op_event in op_list: + if step_event.ts_include(op_event): + target_ops.append(op_event) + target_ops.sort(key=lambda x: convert_to_float(x.ts)) + return target_ops + + def _collector_post_process(self): + # 按step过滤collector中的算子,并将过滤后的算子设置为当前dataset的property,与原始TimelineEventDataset的property保持一致 + for collector_name, collector in self.collector_map.items(): + logger.debug("Start post process for operator collector: %s", collector_name) + if collector.require_filter_by_step: + logger.debug("Operator Collector %s requires filter ops by step %s", collector_name, self.step) + target_op_list = self._get_target_ops_by_step(collector.op_list) + else: + logger.debug("Operator Collector %s use operators of all step for analysis", collector_name) + target_op_list = collector.op_list + + logger.debug("Source number of ops is %s, number of ops after filtered by rank is %s", + len(collector.op_list), len(target_op_list)) + + collector_kwargs = self.get_post_process_kwargs(collector_name) + collector.post_process(target_op_list, **collector_kwargs) + for property_name, property_value in collector.attribute_to_dataset.items(): + setattr(self, property_name, property_value) - def _add_event_for_fusion_ops(self, event): - if event.name.lower().startswith(f"{const.ATEN}{const.ATEN_SEP}") or event.name.lower().startswith( - f"{const.NPU}{const.ATEN_SEP}"): - self._add_aten(event) - return - if event.name.startswith(f"{const.OPTIMIZER}.{const.OPTIMIZER_STEP}{const.OPTIMIZER_SEP}"): - self._add_optimizer(event) - return +@singleton +class ScheduleAnalysisDataset(BaseTimelineEventDataset): + collector_map = OrderedDict( + StepCollector=StepCollector(), + MemCollector=MemCollector(), + OpCompileCollector=OpCompileCollector(), + SynchronizeStreamCollector=SynchronizeStreamCollector(), + DataloaderCollector=DataloaderCollector(), + SyncBNCollector=SyncBNCollector(), + AtenCollector=AtenCollector(), + OptimizerCollector=OptimizerCollector(), + GcCollector=GcCollector(), + FreeEventsCollector=FreeEventsCollector(), + AclEventsCollector=AclEventsCollector() + ) + + def __init__(self, collection_path, data: dict, build_dataset=True, **kwargs) -> None: + super().__init__(collection_path, data, build_dataset, **kwargs) + self.aten = None + self.synchronize_stream = None + self._collector_post_process() + self._post_process() - def _add_event_for_op_stack(self, event): - if event.name.lower() == const.TORCH_TO_NPU: - self._add_torch_to_npu(event) + def _post_process(self): + # eliminate sub aten operator of the first level aten operator by 'ts' and 'dur', + # keep the first level aten operator contiguous + formated_atens = [] + if not hasattr(self, "aten"): return - if event.args.get(const.CALL_STACKS): - self._add_ops_with_stack(event) - return + for event in sorted(self.aten, key=lambda x: x.get("ts", -1)): + if event.name.startswith(const.ATEN): + if not formated_atens or not formated_atens[-1].ts_include(event): + formated_atens.append(event) - if event.args.get(const.TASK_TYPE) and event.args.get(const.TASK_TYPE) in [const.AI_CORE, const.AI_CPU]: - self._add_ops_with_task_type(event) - return + self.aten = formated_atens - if event.name and event.ts and event.name == const.ACL_TO_NPU: - self._add_acl_to_npu(event) - return - def _post_process(self): - # eliminate sub aten operator of the first level aten operator by 'ts' and 'dur', - # keep the first level aten operator contiguous - formated_atens = [] - for aten_event in sorted(self._aten, key=lambda x: x.get("ts", -1)): - if not formated_atens or not formated_atens[-1].ts_include(aten_event): - formated_atens.append(aten_event) - self._aten = formated_atens +@singleton +class ComputationAnalysisDataset(BaseTimelineEventDataset): + collector_map = OrderedDict( + StepCollector=StepCollector(), + SpecificTaskTypeOpCollector=SpecificTaskTypeOpCollector(), + TorchToNpuCollector=TorchToNpuCollector(), + AclToNpuCollector=AclToNpuCollector(), + OpStackCollector=OpStackCollector(), + FrequencyCollector=FrequencyCollector(), + ) + + def __init__(self, collection_path, data: dict, build_dataset=True, **kwargs) -> None: + super().__init__(collection_path, data, build_dataset, **kwargs) + self._collector_post_process() diff --git a/profiler/advisor/dataset/timeline_op_collector/__init__.py b/profiler/advisor/dataset/timeline_op_collector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/dataset/timeline_op_collector/timeline_op_collector.py b/profiler/advisor/dataset/timeline_op_collector/timeline_op_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..aa44bb6c0f0fa667aad34ed760ebdc5b9af13965 --- /dev/null +++ b/profiler/advisor/dataset/timeline_op_collector/timeline_op_collector.py @@ -0,0 +1,398 @@ +import logging +import math +import os +from abc import abstractmethod, ABCMeta + +from profiler.advisor.common import constant as const +from profiler.advisor.common.timeline.event import TimelineEvent +from profiler.advisor.utils.utils import convert_to_float +from profiler.cluster_analyse.common_func.file_manager import FileManager + +logger = logging.getLogger() + + +class BaseOpCollector(metaclass=ABCMeta): + + def __init__(self): + self.attribute_to_dataset = {} + self.op_list = [] + self.require_filter_by_step = True + + @abstractmethod + def add_op(self): + """ add timeline event into self.op_list, and then will filter event in self.op_list by specific step + """ + pass + + @abstractmethod + def post_process(self): + """ convert self.op_list to required format like dict, set and so on and then record the final object into + self.attribute_to_dataset which used to set property of timeline event dataset + """ + pass + + +class StepCollector(BaseOpCollector): + KEY_WORD = "ProfilerStep" + + def __init__(self): + super().__init__() + self.require_filter_by_step = False + + def add_op(self, event): + if event.name.startswith(self.KEY_WORD): + self.op_list.append(event) + + def post_process(self, *args, **kwargs): + self.attribute_to_dataset["profiler_step"] = self.op_list + + +class OpCompileCollector(BaseOpCollector): + def __init__(self): + super().__init__() + self._total_op_compile_counter = 0 + self._total_op_compile_time = 0.0 + + @property + def total_time(self): + return self._total_op_compile_time + + @property + def total_count(self): + return self._total_op_compile_counter + + def is_empty(self): + return self._total_op_compile_counter == 0 + + def update(self, event: TimelineEvent): + self._total_op_compile_time += float(event.dur) + self._total_op_compile_counter += 1 + + def unset(self): + self._total_op_compile_counter = 0 + self._total_op_compile_time = 0.0 + + def add_op(self, event): + if event.name == const.OP_COMPILE_NAME or event.args.get("id") == const.OP_COMPILE_ID: + self.op_list.append(event) + + def post_process(self, target_op_list, **kwargs): + for op in target_op_list: + self.update(op) + + self.attribute_to_dataset["ops_compile"] = self + + +class SynchronizeStreamCollector(BaseOpCollector): + + def __init__(self): + super().__init__() + self.require_filter_by_step = False + + def add_op(self, event): + if event.name.startswith(const.SYNC_STREAM) or event.name.startswith(const.NODE_LAUNCH): + self.op_list.append(event) + + def post_process(self, *args, **kwargs): + self.op_list.sort(key=lambda x: x.ts) + + self.attribute_to_dataset["synchronize_stream"] = self.op_list + + +class MemCollector(BaseOpCollector): + MEMORY_OP_NAME = ["AscendCL@aclMallocMemInner", "AscendCL@aclrtFreePhysical", "AscendCL@aclrtFree"] + + def __init__(self): + super().__init__() + self.mem_op_info = {} + self.rule = self._load_rule() + + @staticmethod + def _load_rule(): + memory_rule_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), + "rules", + "memory.yaml") + + memory_rule = FileManager.read_yaml_file(memory_rule_path) + return memory_rule + + def add_op(self, event): + if event.name not in self.MEMORY_OP_NAME: + return + self.op_list.append(event) + + def post_process(self, target_op_list, **kwargs): + for op in target_op_list: + if op.name not in self.mem_op_info: + self.mem_op_info[op.name] = dict(count=0, total_dur=0) + self.mem_op_info[op.name]["count"] += 1 + self.mem_op_info[op.name]["total_dur"] += float(op.dur) + + self.attribute_to_dataset["memory_ops"] = self + + +class DataloaderCollector(BaseOpCollector): + key_word = "dataloader" + + def __init__(self): + super().__init__() + + def add_op(self, event): + if self.key_word in event.name.lower(): + self.op_list.append(TimelineEvent({ + "name": event.name, "dataset_index": event.dataset_index, "ts": event.ts, "dur": event.dur, + "stack": event.args.get("Call stack") + })) + + def post_process(self, *args, **kwargs): + self.attribute_to_dataset["dataloader"] = self.op_list + + +class SyncBNCollector(BaseOpCollector): + key_word = "syncbatchnorm" + + def __init__(self): + super().__init__() + + def add_op(self, event): + if event.name.lower() == self.key_word: + self.op_list.append(TimelineEvent({ + "name": event.name, "dataset_index": event.dataset_index, "ts": event.ts, "dur": event.dur + })) + + def post_process(self, target_op_list, **kwargs): + self.attribute_to_dataset["sync_batchnorm"] = target_op_list + + +class AtenCollector(BaseOpCollector): + + def __init__(self): + super().__init__() + + def add_op(self, event): + if event.name.lower().startswith(f"{const.ATEN}{const.ATEN_SEP}") or event.name.lower().startswith( + f"{const.NPU}{const.ATEN_SEP}"): + self._add_aten(event) + return + + # 检查cann层同步操作,根据时间窗口索引到host侧的aten算子并给出堆栈 + if event.name.startswith(const.SYNC_STREAM): + self._add_aten(event) + + def post_process(self, target_op_list, **kwargs): + self.attribute_to_dataset["aten"] = target_op_list + + def _add_aten(self, event: TimelineEvent): + self.op_list.append(TimelineEvent({ + "name": event.name, "dataset_index": event.dataset_index, "ts": event.ts, "dur": event.dur + })) + + +class OptimizerCollector(BaseOpCollector): + + def __init__(self): + super().__init__() + + def add_op(self, event): + if event.name.startswith(f"{const.OPTIMIZER}.{const.OPTIMIZER_STEP}{const.OPTIMIZER_SEP}"): + self.op_list.append(TimelineEvent( + {"name": event.name, "dataset_index": event.dataset_index, "ts": event.ts, "dur": event.dur})) + + def post_process(self, target_op_list, **kwargs): + self.attribute_to_dataset["optimizer"] = target_op_list + + +class FrequencyCollector(BaseOpCollector): + KEY_WORD = "AI Core Freq" + + def __init__(self): + super().__init__() + self._previous_freq_index = -1 + + @staticmethod + def get_op_frequency(ai_core_ops, ai_core_freq): + ai_core_freq.sort(key=lambda x: float(x.ts)) + op_freq_record = {} + + op_index, freq_index = 0, 0 + while op_index < len(ai_core_ops) and freq_index < len(ai_core_freq): + op_event = ai_core_ops[op_index] + op_end_time = convert_to_float(op_event.ts) + convert_to_float(op_event.dur) + op_freq_list = [] + while freq_index < len(ai_core_freq): + freq_event = ai_core_freq[freq_index] + if convert_to_float(freq_event.end) < op_end_time: + op_freq_list.append(convert_to_float(freq_event.args.MHz)) + freq_index += 1 + continue + elif convert_to_float(freq_event.ts) < op_end_time: + if op_event.name not in op_freq_record: + op_freq_record[op_event.name] = {"count": 0, "dur": 0, "freq_list": []} + op_freq_record[op_event.name]["count"] += 1 + op_freq_record[op_event.name]["dur"] += convert_to_float(op_event.dur) + op_freq_list.append(convert_to_float(freq_event.args.MHz)) + op_freq_record[op_event.name]["freq_list"].append(min(op_freq_list)) + break + else: + break + + op_index += 1 + return op_freq_record + + def add_op(self, event): + if event.name == self.KEY_WORD: + if self._previous_freq_index != -1: + self.op_list[self._previous_freq_index]["end"] = event.get("ts", float(math.inf)) + self._previous_freq_index += 1 + event.setdefault("end", float(math.inf)) + self.op_list.append(event) + + def post_process(self, target_op_list, **kwargs): + ai_core_ops = kwargs.get("ai_core_ops", []) + if not ai_core_ops: + return + ai_core_ops.sort(key=lambda x: float(x.ts)) + op_freq = FrequencyCollector.get_op_frequency(ai_core_ops, target_op_list) + self.attribute_to_dataset["op_freq"] = op_freq + + +class SpecificTaskTypeOpCollector(BaseOpCollector): + + def __init__(self, op_type_list=None): + super().__init__() + self.op_type_list = op_type_list if op_type_list else [const.AI_CPU, const.AI_CORE, const.MIX_AIC] + + def add_op(self, event): + if event.args.get(const.TASK_TYPE) and event.args.get(const.TASK_TYPE) in self.op_type_list: + self.op_list.append( + TimelineEvent( + { + const.TASK_TYPE: event.args.get(const.TASK_TYPE), + "task_id": event.args.get("Task Id"), + "tid": event.tid, + "name": event.name, + "ts": str(event.ts), + "dur": str(event.dur) + } + ) + ) + + def post_process(self, target_op_list, **kwargs): + op_map = dict() + for op in target_op_list: + key = f"{op.name}-{op.ts}" + op_map[key] = op + + self.attribute_to_dataset["ops_with_task_type"] = op_map + self.attribute_to_dataset["task_op_names"] = list( + set([event_key.split("-")[0] for event_key in op_map.keys()])) + + +class TorchToNpuCollector(BaseOpCollector): + def __init__(self): + super().__init__() + + def add_op(self, event): + if event.name.lower() == const.TORCH_TO_NPU: + self.op_list.append(TimelineEvent({"tid": event.tid, "ts": str(event.ts), "ph": event.ph, "id": event.id})) + + def post_process(self, target_op_list, **kwargs): + op_map = dict() + for op in target_op_list: + key = f"{op.ph}-{op.id}" + op_map[key] = op + + self.attribute_to_dataset["torch_to_npu"] = op_map + + +class AclToNpuCollector(BaseOpCollector): + def __init__(self): + super().__init__() + + def add_op(self, event): + if event.name and event.ts and event.name == const.ACL_TO_NPU: + self.op_list.append(TimelineEvent({"ts": event.ts})) + + def post_process(self, target_op_list, **kwargs): + op_record = set(str(op.ts) for op in target_op_list) + self.attribute_to_dataset["acl_to_npu"] = op_record + + +class OpStackCollector(BaseOpCollector): + def __init__(self): + super().__init__() + + def add_op(self, event): + if event.args.get(const.CALL_STACKS): + self.op_list.append( + TimelineEvent({"name": event.name, "dataset_index": event.dataset_index, "ts": event.ts})) + + def post_process(self, target_op_list, **kwargs): + op_map = dict() + for op in target_op_list: + op_map[str(op.ts)] = op + + self.attribute_to_dataset["ops_with_stack"] = op_map + + +class GcCollector(BaseOpCollector): + def __init__(self): + super().__init__() + + def add_op(self, event): + if event.cat and isinstance(event.cat, str) and event.cat.lower() == "gc": + self.op_list.append(TimelineEvent( + {"name": event.name, "dataset_index": event.dataset_index, "ts": event.ts, "dur": event.dur})) + + def post_process(self, target_op_list, **kwargs): + self.attribute_to_dataset["gc_events"] = self.op_list + + +class FreeEventsCollector(BaseOpCollector): + def __init__(self): + super().__init__() + + @staticmethod + def _load_rule(): + sync_stream_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), + "rules", + "gc.yaml") + + gc_rule = FileManager.read_yaml_file(sync_stream_rule_path) + return gc_rule + + def add_op(self, event): + if event.name.lower() == const.FREE: + self.op_list.append(event) + + def post_process(self, target_op_list, **kwargs): + gc_rule = self._load_rule() + if os.getenv(const.FREE_DURATION_FOR_GC_ANALYSIS): + max_free_threshold = convert_to_float(os.getenv(const.FREE_DURATION_FOR_GC_ANALYSIS)) + else: + max_free_threshold = gc_rule.get("max_free_threshold") + + large_free_events = [] + + for op in target_op_list: + if convert_to_float(op.dur) > max_free_threshold: + large_free_events.append(op) + + large_free_events.sort(key=lambda x: convert_to_float(x.ts)) + self.attribute_to_dataset["large_free_events"] = large_free_events + + +class AclEventsCollector(BaseOpCollector): + ACL_EVENT_PREFIX = "AscendCL@" + + def __init__(self): + super().__init__() + + def add_op(self, event): + if event.name.startswith(self.ACL_EVENT_PREFIX): + self.op_list.append(event) + + def post_process(self, target_op_list, **kwargs): + target_op_list.sort(key=lambda x: convert_to_float(x.ts)) + self.attribute_to_dataset["acl_events"] = target_op_list diff --git a/profiler/advisor/display/html/priority_background_color.py b/profiler/advisor/display/html/priority_background_color.py new file mode 100644 index 0000000000000000000000000000000000000000..6b03747a81b532364816e171b846adac1f1883fa --- /dev/null +++ b/profiler/advisor/display/html/priority_background_color.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class PriorityBackgroundColor: + high = "#B5495B" + medium = "#fcaf17" + low = "#65c294" diff --git a/profiler/advisor/display/html/render.py b/profiler/advisor/display/html/render.py index 8ea7c9e0fc22c7da71a673e399fcfc231fbf1453..d20df9a7601f94c6a276abc2f17dff5c717ebbf3 100644 --- a/profiler/advisor/display/html/render.py +++ b/profiler/advisor/display/html/render.py @@ -1,6 +1,22 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import logging from typing import List, Dict +from collections import defaultdict, OrderedDict from jinja2 import Environment, FileSystemLoader from profiler.advisor.common import constant @@ -13,33 +29,74 @@ logger = logging.getLogger() @singleton class HTMLRender: + SUPPORTED_KEYS = [ + "main", "overall", "comparison", "computation", "schedule", "communication", "dataloader", + "memory", + ] + PERFORMANCE_PROBLEM_ANALYSIS = "performance_problem_analysis" + def __init__(self): self.html = "" - self.render_list: Dict[str, List] = {} + self.render_list = defaultdict(list) def render_html(self, template_dir: str = "templates", template_name: str = "main.html", template_header=constant.DEFAULT_TEMPLATE_HEADER): - self.html = self.render_template("main", template_dir, template_name, render_list=self.render_list, + + # 确保overall 和 comparison 在 performance problem analysis 之前 + sorted_render_htmls = OrderedDict() + for key in ["overall", "comparison"]: + if key in self.render_list: + sorted_render_htmls[key] = self.render_list.get(key) + for key, html in self.render_list.items(): + if key in sorted_render_htmls: + continue + sorted_render_htmls[key] = html + + self.html = self.render_template("main", template_dir, template_name, render_list=sorted_render_htmls, template_header=template_header) - def render_template(self, key: str, template_dir: str, template_name: str, **kwargs): + def get_rendered_html(self, key: str, template_dir: str, template_name: str, **kwargs): + if key not in self.SUPPORTED_KEYS: + error_msg = f"Error render template key {key}, optionals are {self.SUPPORTED_KEYS}" + logger.error(error_msg) + raise Exception(error_msg) + if not os.path.isabs(template_dir): template_dir = os.path.join(os.path.dirname(__file__), template_dir) env = Environment(loader=FileSystemLoader(template_dir), autoescape=True) template = env.get_template(template_name) + if "priority" not in kwargs: + kwargs["priority"] = "low priority" rendered_html = template.render(**kwargs) - if key not in self.render_list: - self.render_list[key] = [] - self.render_list[key].append(rendered_html) + return rendered_html + + def render_template(self, key: str, template_dir: str, template_name: str, **kwargs): + rendered_html = self.get_rendered_html(key, template_dir, template_name, **kwargs) + + if not kwargs.get("add_render_list", True): + return rendered_html + + if key in ["main", "overall", "comparison"]: + if key not in self.render_list: + self.render_list[key] = [] + self.render_list[key].append(rendered_html) + else: + if self.PERFORMANCE_PROBLEM_ANALYSIS not in self.render_list: + self.render_list[self.PERFORMANCE_PROBLEM_ANALYSIS] = {} + if key not in self.render_list[self.PERFORMANCE_PROBLEM_ANALYSIS]: + self.render_list[self.PERFORMANCE_PROBLEM_ANALYSIS][key] = [] + self.render_list[self.PERFORMANCE_PROBLEM_ANALYSIS][key].append(rendered_html) + return rendered_html def save_to_file(self, save_path: str): + save_path = os.path.join(Config().work_path, save_path) if not save_path.endswith(".html"): logger.error("Skip save html file because file name must endswith `.html`, " "but got %s.", os.path.basename(save_path)) return safe_write(self.html, save_path) - logger.info("Save suggestion to %s.", os.path.join(Config().work_path, save_path)) + logger.info("Save suggestion to %s.", save_path) diff --git a/profiler/advisor/display/html/templates/affinity_api.html b/profiler/advisor/display/html/templates/affinity_api.html index 4d12c3e37536392d122f85fc6ef3a4fcc123ef77..7cd3d7ad33d0220c7aba055721eddf049161a0d8 100644 --- a/profiler/advisor/display/html/templates/affinity_api.html +++ b/profiler/advisor/display/html/templates/affinity_api.html @@ -1,11 +1,14 @@ {% if result|length > 0 %}

-

Affinity API Issues

+

Affinity API Issues

+ {% if rank is not none %} + Analysis of rank {{ rank|safe }}. + {% endif %} The analysis results of following affinity APIs are based on runtime env - cann-{{ cann_version }} + cann-{{ cann_version }} and - torch-{{ torch_version }} + torch-{{ torch_version }}
@@ -13,7 +16,7 @@ Suggestion: These APIs have no code stack. If parameter 'with_stack=False' was set while profiling, please refer to Ascend PyTorch Profiler to set - 'with_stack=True'. Otherwise, ignore following affinity APIs due to backward broadcast lack of stack. + 'with_stack=True'. Otherwise, ignore following affinity APIs due to backward broadcast lack of stack. {% endif %} {% for api_name, stacks in result.items() %} diff --git a/profiler/advisor/display/html/templates/ai_core_frequency.html b/profiler/advisor/display/html/templates/ai_core_frequency.html new file mode 100644 index 0000000000000000000000000000000000000000..405460ac9616740613bc337d705d617cc9de9287 --- /dev/null +++ b/profiler/advisor/display/html/templates/ai_core_frequency.html @@ -0,0 +1,30 @@ +{% if data|length > 0 %} +
+

AI CORE Frequency Issues

+
+ {% if rank is not none %} + Analysis of rank {{ rank|safe }}. + {% endif %} + Issue: {{ desc }} +
+ Suggestion: {{ suggestion }} +

+ + + {% for header in headers %} + + {% endfor %} + + + {% for row in data %} + + {% for element in row %} + + {% endfor %} + + {% endfor %} +
{{ header }}
{{ element|safe }}
+ +
+
+{% endif %} \ No newline at end of file diff --git a/profiler/advisor/display/html/templates/communication_retransmission_analysis.html b/profiler/advisor/display/html/templates/communication_retransmission_analysis.html new file mode 100644 index 0000000000000000000000000000000000000000..75754fde72467934ac92166ebec7ad5440e55896 --- /dev/null +++ b/profiler/advisor/display/html/templates/communication_retransmission_analysis.html @@ -0,0 +1,40 @@ +
+

Communication Retransmission Analysis

+
+ {{ desc }} + + + + + + + {% for item in solutions %} + {% set rowloop = loop %} + {% for key, value in item.items() %} + + + + + {% endfor %} + {% endfor %} +
Suggestions
{{ rowloop.index }}. {{ value.desc }}
+

+ {{ desc }} + + + {% for header in headers %} + + {% endfor %} + + + {% for row in data %} + + {% for element in row %} + + {% endfor %} + + {% endfor %} +
{{ header }}
{{ element|safe }}
+ +
+
diff --git a/profiler/advisor/display/html/templates/comparison.html b/profiler/advisor/display/html/templates/comparison.html new file mode 100644 index 0000000000000000000000000000000000000000..b81802d6b0505ca4a21e5174a0158b800d4a43ec --- /dev/null +++ b/profiler/advisor/display/html/templates/comparison.html @@ -0,0 +1,25 @@ +{% if rows|length > 0 %} +
+

{{ sheet_name }}

+
+ Issue: {{ desc }} +

+ + + {% for header in headers %} + + {% endfor %} + + + {% for row in rows %} + + {% for element in row %} + + {% endfor %} + + {% endfor %} +
{{ header }}
{{ element|safe }}
+ +
+
+{% endif %} \ No newline at end of file diff --git a/profiler/advisor/display/html/templates/contention.html b/profiler/advisor/display/html/templates/contention.html new file mode 100644 index 0000000000000000000000000000000000000000..3d7fb89c5e0992524ba84dc13b9ba177905cf632 --- /dev/null +++ b/profiler/advisor/display/html/templates/contention.html @@ -0,0 +1,41 @@ +
+

Bandwidth Contention Analysis

+
+ {{ desc }} + + + + + + + {% for item in solutions %} + {% set rowloop = loop %} + {% for key, value in item.items() %} + + + + + {% endfor %} + {% endfor %} +
Suggestions
{{ rowloop.index }}. {{ value.desc }}
+

+ The following table lists the {{topk}} operators with the + most severe performance deterioration. + + + {% for header in headers %} + + {% endfor %} + + + {% for row in data %} + + {% for element in row %} + + {% endfor %} + + {% endfor %} +
{{ header }}
{{ element|safe }}
+ +
+
diff --git a/profiler/advisor/display/html/templates/environment_variable.html b/profiler/advisor/display/html/templates/environment_variable.html new file mode 100644 index 0000000000000000000000000000000000000000..ab95096393910e4c1c3a79d5a57640ccddc57928 --- /dev/null +++ b/profiler/advisor/display/html/templates/environment_variable.html @@ -0,0 +1,21 @@ +
+

Environment Variable Issues

+
+ + + {% for header in result.get("headers") %} + + {% endfor %} + + + {% for row in result.get("data") %} + + {% for value in row %} + + {% endfor %} + + {% endfor %} + +
{{ header }}
{{ value|safe }}
+
+
\ No newline at end of file diff --git a/profiler/advisor/display/html/templates/gc.html b/profiler/advisor/display/html/templates/gc.html new file mode 100644 index 0000000000000000000000000000000000000000..205e1b3b9ede3282189864f116a9c650b59626df --- /dev/null +++ b/profiler/advisor/display/html/templates/gc.html @@ -0,0 +1,42 @@ + +
+

GC Analysis

+
+ {% if rank is not none %} + Analysis of rank {{ rank|safe }}. + {% endif %} + {{ desc }} + + + + + {% for item in solutions %} + {% set rowloop = loop %} + {% for key, value in item.items() %} + + + + {% endfor %} + {% endfor %} +
Suggestions
{{ rowloop.index }}. {{ value.desc }}
+ {% if datas|safe %} + The details of top {{ num }} garbage collection events are as follows: +

+ + + {% for header in headers %} + + {% endfor %} + + + {% for row in datas %} + + {% for element in row %} + + {% endfor %} + + {% endfor %} +
{{ header }}
{{ element|safe }}
+ {% endif %} +
+
diff --git a/profiler/advisor/display/html/templates/main.html b/profiler/advisor/display/html/templates/main.html index 3727125b419547fc6a9ac9743eab34e1e1b76256..9317abba543dacabf19d6b7967acf496e7aa8dc9 100644 --- a/profiler/advisor/display/html/templates/main.html +++ b/profiler/advisor/display/html/templates/main.html @@ -137,10 +137,21 @@

Performance Optimization Suggestions

+ +
+ Optimization Priority: +
+ High +
+ Medium +
+ Low +
+ {% for key, renders in render_list.items() %} - {% if key == 'operator'%} + {% if key != 'performance_problem_analysis' %}
-

computation

+

{{ key }}

{% for render in renders %} {{render|safe}} @@ -148,14 +159,25 @@
{% else %} +
-

{{ key }}

+

performance problem analysis

- {% for render in renders %} - {{render|safe}} - {% endfor %} + + + {% for sub_key, sub_renders in renders.items() %} +
+

{{ sub_key }}

+
+ {% for render in sub_renders %} + {{render|safe}} + {% endfor %} +
+
+ {% endfor %}
+ {% endif %} {% endfor %}