Ai
3 Star 6 Fork 0

Gitee 极速下载/viztracer

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
此仓库是为了提升国内下载速度的镜像仓库，每日同步一次。原始仓库：https://github.com/gaogaotiantian/viztracer
克隆/下载
test_torch.py 4.36 KB
一键复制 编辑 原始数据 按行查看 历史
Tian Gao 提交于 2025-10-27 14:50 +08:00 . Relax torch alignment test on Windows (#626)
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/gaogaotiantian/viztracer/blob/master/NOTICE.txt
import platform
import sys
from .cmdline_tmpl import CmdlineTmpl
from .package_env import package_matrix
def support_torch():
    """Return whether the torch-installed variant of the tests should run here.

    Used to decide whether ``"torch"`` is added to the package matrix.

    * Linux: always supported.
    * Windows: supported only on Python < 3.13.
    * macOS: supported only on arm64 machines with Python < 3.13.
    * Any other platform: not supported.

    Always returns a ``bool``.
    """
    if "linux" in sys.platform:
        return True
    if sys.platform == "win32":
        return sys.version_info < (3, 13)
    if sys.platform == "darwin":
        return platform.machine().lower() == "arm64" and sys.version_info < (3, 13)
    # Unknown platforms previously fell off the end and returned None;
    # return an explicit bool so the contract is uniform.
    return False
@package_matrix(["~torch", "torch"] if support_torch() else ["~torch"])
class TestTorch(CmdlineTmpl):
    """Tests for VizTracer's ``log_torch`` option, with and without torch.

    The ``package_matrix`` decorator supplies ``self.pkg_config``. Each case
    branches on ``self.pkg_config.has("torch")``:

    * With torch, Python-level ``torch.empty`` calls and the matching
      ``aten::empty`` ops should both appear in the trace.
    * Without torch, enabling ``log_torch`` is expected to fail.

    (``"~torch"`` presumably denotes the "torch not installed" configuration;
    TODO confirm against ``package_env``.)
    """

    # Number of ``torch.empty`` calls made by the traced scripts; the
    # alignment check expects exactly this many Python and aten events.
    _TORCH_CALLS = 100

    def test_entry(self):
        # We only want to install/uninstall torch once, so do all tests in one function
        with self.subTest("basic"):
            self.case_basic()
        with self.subTest("cmdline"):
            self.case_cmdline()

    def _check_torch_alignment(self, data, windows_margin):
        """Assert each ``torch.empty`` Python event encloses its ``aten::empty`` op.

        :param data: the loaded trace; only ``data["traceEvents"]`` is read,
            and each event's ``name``, ``ts`` and ``dur`` fields.
        :param windows_margin: slack (same units as ``ts``, i.e. microseconds
            per the original comments) allowed on each bound on Windows.

        Behavior by platform:

        * Linux: the Python event must strictly contain the aten event
          (equivalent to a margin of 0).
        * Windows: both bounds are relaxed by ``windows_margin``.
        * Other platforms (macOS): timing is not checked at all; only the
          event counts are.
        """
        events = data["traceEvents"]
        # Pair events by start time rather than by the order they happen to
        # be written in, so the check does not depend on trace ordering.
        py_events = sorted(
            (e for e in events if e["name"] == "torch.empty"),
            key=lambda e: e["ts"],
        )
        aten_events = sorted(
            (e for e in events if e["name"] == "aten::empty"),
            key=lambda e: e["ts"],
        )
        self.assertEqual(len(py_events), self._TORCH_CALLS)
        self.assertEqual(len(aten_events), self._TORCH_CALLS)

        if "linux" in sys.platform:
            # Linux timing is reliable: require strict containment.
            margin = 0
        elif sys.platform == "win32":
            # Windows timestamps are close but not exact; allow some slack.
            margin = windows_margin
        else:
            # macOS timestamps are not reliable enough to compare.
            return

        for py, aten in zip(py_events, aten_events):
            self.assertLess(py["ts"], aten["ts"] + margin)
            self.assertGreater(py["ts"] + py["dur"], aten["ts"] + aten["dur"] - margin)

    def case_basic(self):
        """Exercise ``log_torch`` through the ``VizTracer`` context-manager API."""
        assert self.pkg_config is not None
        if self.pkg_config.has("torch"):
            script = f"""
import torch
from viztracer import VizTracer
with VizTracer(log_torch=True, verbose=0):
    for i in range({self._TORCH_CALLS}):
        torch.empty(i)
"""
            self.template(
                [sys.executable, "cmdline_test.py"],
                script=script,
                # The API case keeps the tighter 50us Windows tolerance.
                check_func=lambda data: self._check_torch_alignment(data, windows_margin=50),
            )
        else:
            script = """
from viztracer import VizTracer
_ = VizTracer(log_torch=True, verbose=0)
"""
            # Without torch, constructing a tracer with log_torch=True is
            # expected to fail with ModuleNotFoundError.
            self.template(
                [sys.executable, "cmdline_test.py"],
                script=script,
                expected_output_file=None,
                success=False,
                expected_stderr=".*ModuleNotFoundError.*",
            )

    def case_cmdline(self):
        """Exercise ``log_torch`` through the ``viztracer --log_torch`` CLI."""
        assert self.pkg_config is not None
        if self.pkg_config.has("torch"):
            script = f"""
import torch
for i in range({self._TORCH_CALLS}):
    torch.empty(i)
"""
            self.template(
                ["viztracer", "--log_torch", "cmdline_test.py"],
                script=script,
                # The CLI case was relaxed to a 100us Windows tolerance.
                check_func=lambda data: self._check_torch_alignment(data, windows_margin=100),
            )
        else:
            # Without torch, the --log_torch flag is expected to make the run fail.
            self.template(["viztracer", "--log_torch", "cmdline_test.py"], script="pass", success=False)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
C/C++
1
https://gitee.com/mirrors/viztracer.git
git@gitee.com:mirrors/viztracer.git
mirrors
viztracer
viztracer
master

搜索帮助