3 Star 1 Fork 3

Gitee 极速下载 / Kornia

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
此仓库是为了提升国内下载速度的镜像仓库,每日同步一次。 原始仓库: https://github.com/kornia/kornia
克隆/下载
conftest.py 9.78 KB
一键复制 编辑 原始数据 按行查看 历史
import sys
from functools import partial
from itertools import product
from typing import Dict
import numpy as np
import pytest
import torch
import kornia
from kornia.utils._compat import torch_version
# Compile backends reported by torch._dynamo. The list stays empty when
# torch._dynamo cannot be imported.
_backends_non_experimental = []
try:
    import torch._dynamo

    _backends_non_experimental = torch._dynamo.list_backends()
except ImportError:
    pass
def get_test_devices() -> Dict[str, torch.device]:
    """Collect the devices on which the test suite can run.

    The CPU is always included. CUDA, TPU (XLA) and MPS entries are added only
    when the corresponding availability check succeeds.

    Return:
        dict(str, torch.device): mapping from device name to device object.
    """
    available: Dict[str, torch.device] = {"cpu": torch.device("cpu")}

    if torch.cuda.is_available():
        available["cuda"] = torch.device("cuda:0")

    if kornia.xla_is_available():
        # Imported lazily: torch_xla is only needed when XLA is reported available.
        import torch_xla.core.xla_model as xm

        available["tpu"] = xm.xla_device()

    # `torch.backends` may not expose an `mps` attribute, so probe for it first.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        available["mps"] = torch.device("mps")

    return available
def get_test_dtypes() -> Dict[str, torch.dtype]:
    """Build the mapping of floating point dtypes used to test the source code.

    Return:
        dict(str, torch.dtype): mapping from dtype name to ``torch.dtype``.
    """
    # Each key is also the attribute name of the dtype on the `torch` module.
    dtype_names = ("bfloat16", "float16", "float32", "float64")
    return {name: getattr(torch, name) for name in dtype_names}
# Devices and dtypes the test suite can be parametrized over, plus the set of
# optimizer backends that `--optimizer all` expands to.
TEST_DEVICES: Dict[str, torch.device] = get_test_devices()
TEST_DTYPES: Dict[str, torch.dtype] = get_test_dtypes()
TEST_OPTIMIZER_BACKEND = {"", None, "jit", *_backends_non_experimental}
# (device name, dtype name) pairs to be excluded from testing,
# e.g. {("cpu", "float16")}.
# NOTE: an empty set has to be spelled `set()`; a bare `{}` creates a dict.
DEVICE_DTYPE_BLACKLIST = set()
@pytest.fixture()
def device(device_name) -> torch.device:
    """Resolve the parametrized ``device_name`` to its entry in ``TEST_DEVICES``."""
    selected: torch.device = TEST_DEVICES[device_name]
    return selected
@pytest.fixture()
def dtype(dtype_name) -> torch.dtype:
    """Resolve the parametrized ``dtype_name`` to its entry in ``TEST_DTYPES``."""
    selected: torch.dtype = TEST_DTYPES[dtype_name]
    return selected
@pytest.fixture()
def torch_optimizer(optimizer_backend):
    """Return a callable that applies the requested optimizer backend.

    - ``"jit"`` gives ``torch.jit.script``;
    - a falsy backend (``""`` or ``None``) gives an identity function;
    - any other value gives ``torch.compile`` bound to that backend, provided
      ``torch.compile`` exists, the platform is linux and the python/torch
      combination is not one of the excluded ones.

    In every remaining situation the requesting test is skipped.
    """
    if optimizer_backend == "jit":
        return torch.jit.script

    if not optimizer_backend:
        return lambda x: x

    compile_usable = hasattr(torch, "compile") and sys.platform == "linux"
    if compile_usable:
        # torch compile just have support for python 3.11 after torch 2.1.0
        excluded_combo = sys.version_info[:2] == (3, 11) and torch_version() in {"2.0.0", "2.0.1"}
        if not excluded_combo:
            torch._dynamo.reset()
            # TODO: explore the others parameters of torch compile
            return partial(torch.compile, backend=optimizer_backend)

    pytest.skip(f"skipped because {torch.__version__} not have `compile` available! Failed to setup dynamo.")
def pytest_generate_tests(metafunc):
    """Parametrize tests over the requested devices, dtypes and optimizer backends.

    For each of the fixtures ``device_name``, ``dtype_name`` and
    ``optimizer_backend`` that the test requests, the matching command line
    option (``--device``, ``--dtype``, ``--optimizer``) is read. The value
    ``"all"`` expands to every known entry; anything else is treated as a
    comma separated list of names.

    All requested fixtures are parametrized together with the cartesian
    product of their values, so every combination of requested fixtures is
    handled. When both a device and a dtype are requested, combinations whose
    ``(device, dtype)`` pair is in ``DEVICE_DTYPE_BLACKLIST`` are dropped,
    whether or not an optimizer backend is part of the combination.

    Args:
        metafunc: the pytest ``Metafunc`` object of the test being collected.
    """
    # (fixture name, values) pairs, always in device, dtype, optimizer order so
    # the generated argnames and test ids stay stable.
    requested = []

    if "device_name" in metafunc.fixturenames:
        raw_value = metafunc.config.getoption("--device")
        requested.append(("device_name", list(TEST_DEVICES) if raw_value == "all" else raw_value.split(",")))

    if "dtype_name" in metafunc.fixturenames:
        raw_value = metafunc.config.getoption("--dtype")
        requested.append(("dtype_name", list(TEST_DTYPES) if raw_value == "all" else raw_value.split(",")))

    if "optimizer_backend" in metafunc.fixturenames:
        raw_value = metafunc.config.getoption("--optimizer")
        requested.append(
            ("optimizer_backend", list(TEST_OPTIMIZER_BACKEND) if raw_value == "all" else raw_value.split(","))
        )

    if not requested:
        return

    names = [name for name, _ in requested]
    if len(requested) == 1:
        # A single argname takes plain values, not 1-tuples.
        metafunc.parametrize(names[0], requested[0][1])
        return

    combos = list(product(*(values for _, values in requested)))
    if "device_name" in names and "dtype_name" in names:
        # Device and dtype are the first two entries of every combination here;
        # compare only that pair so the blacklist also applies when an
        # optimizer backend is part of the combination.
        combos = [combo for combo in combos if combo[:2] not in DEVICE_DTYPE_BLACKLIST]

    metafunc.parametrize(",".join(names), combos)
def pytest_collection_modifyitems(config, items):
    """Skip every test carrying the ``slow`` keyword unless ``--runslow`` is set.

    Args:
        config: object whose ``getoption`` gives access to the command line options.
        items: the collected test items.
    """
    run_slow = config.getoption("--runslow")
    if not run_slow:
        skip_slow = pytest.mark.skip(reason="need --runslow option to run")
        slow_items = [item for item in items if "slow" in item.keywords]
        for slow_item in slow_items:
            slow_item.add_marker(skip_slow)
def pytest_addoption(parser):
    """Register the command line options used by this test configuration.

    ``--device``, ``--dtype`` and ``--optimizer`` store a string value, while
    ``--runslow`` is a boolean flag.

    Args:
        parser: object whose ``addoption`` method registers an option.
    """
    stored_options = (
        ("--device", "cpu"),
        ("--dtype", "float32"),
        ("--optimizer", "inductor"),
    )
    for flag, fallback in stored_options:
        parser.addoption(flag, action="store", default=fallback)

    parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
def _setup_torch_compile():
    """Prepare ``torch.compile`` for the test session.

    Does nothing unless ``torch.compile`` exists and the platform is linux.
    Otherwise it sets the float32 matmul precision to ``"high"`` and wraps a
    dummy function and a dummy module with ``torch.compile``.
    """
    if not (hasattr(torch, "compile") and sys.platform == "linux"):
        return

    print("Setting up torch compile...")
    torch.set_float32_matmul_precision("high")

    class _SquareSum(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return (x**2).sum()

    def _add_and_sum(x, y):
        return (x + y).sum()

    torch.compile(_add_and_sum)
    torch.compile(_SquareSum())
def pytest_sessionstart(session):
    """Run the ``torch.compile`` setup when the test session starts.

    A ``RuntimeError`` whose message contains
    "not yet supported for torch.compile" is ignored; any other
    ``RuntimeError`` is re-raised.
    """
    tolerated_message = "not yet supported for torch.compile"
    try:
        _setup_torch_compile()
    except RuntimeError as ex:
        if tolerated_message not in str(ex):
            raise
    # TODO: cache all torch.load weights/states here to not impact on test suite
def _get_env_info() -> Dict[str, Dict[str, str]]:
    """Gather CPU, GPU, NVIDIA driver and gcc details via ``torch.utils.collect_env``.

    Returns:
        An empty dict when ``torch.utils`` has no ``collect_env`` attribute.
        Otherwise a dict with the keys:

        - ``"cpu"`` / ``"gpu"``: the text reported by ``get_cpu_info`` /
          ``get_gpu_info`` parsed line by line into ``{key: value}``
          (empty dict when nothing is reported);
        - ``"nvidia"`` / ``"gcc"``: whatever ``get_nvidia_driver_version`` /
          ``get_gcc_version`` return, passed through unchanged.
    """
    if not hasattr(torch.utils, "collect_env"):
        return {}

    collect_env = torch.utils.collect_env
    run_lmb = collect_env.run
    separator = ":"

    def _parse_key_values(text) -> Dict[str, str]:
        """Turn newline separated ``key: value`` text into a dict."""
        if not text:
            return {}

        parsed = {}
        for line in text.split("\n"):
            # Split on the first separator only, so a value that itself
            # contains ":" is kept whole instead of being cut down to its last
            # piece. A line without any separator maps to itself.
            parts = line.split(separator, 1)
            parsed[parts[0].strip()] = parts[-1].strip()
        return parsed

    return {
        "cpu": _parse_key_values(collect_env.get_cpu_info(run_lmb)),
        "gpu": _parse_key_values(collect_env.get_gpu_info(run_lmb)),
        "nvidia": collect_env.get_nvidia_driver_version(run_lmb),
        "gcc": collect_env.get_gcc_version(run_lmb),
    }
def pytest_report_header(config):
    """Build the extra header text shown at the top of the pytest report.

    The header lists selected CPU fields, the GPU and gcc information from
    ``_get_env_info``, the versions of kornia, torch, accelerate (when it can
    be imported), kornia_rs and onnx, and the available optimizer backends.
    """
    try:
        import accelerate
    except ImportError:
        accelerate_info = "`accelerate` not found"
    else:
        accelerate_info = f"accelerate-{accelerate.__version__}"

    import kornia_rs
    import onnx

    env_info = _get_env_info()

    cpu_info = ""
    if "cpu" in env_info:
        cpu_fields = env_info["cpu"]
        wanted_fields = ["Model name", "Architecture", "CPU(s)", "Thread(s) per core", "CPU max MHz", "CPU min MHz"]
        # Only the wanted fields that were actually reported are listed.
        rows = [f"\t- {field}: {cpu_fields[field]}" for field in wanted_fields if field in cpu_fields]
        cpu_info = "cpu info:\n" + "\n".join(rows)

    gpu_info = ""
    if "gpu" in env_info:
        gpu_info = f"gpu info: {env_info['gpu']}"

    gcc_info = ""
    if "gcc" in env_info:
        gcc_info = f"gcc info: {env_info['gcc']}"

    nvidia_driver = env_info.get("nvidia")

    return f"""
{cpu_info}
{gpu_info}
main deps:
- kornia-{kornia.__version__}
- torch-{torch.__version__}
- commit: {torch.version.git_version}
- cuda: {torch.version.cuda}
- nvidia-driver: {nvidia_driver}
x deps:
- {accelerate_info}
dev deps:
- kornia_rs-{kornia_rs.__version__}
- onnx-{onnx.__version__}
{gcc_info}
available optimizers: {TEST_OPTIMIZER_BACKEND}
"""
@pytest.fixture(autouse=True)
def add_doctest_deps(doctest_namespace):
    """Expose ``np``, ``torch`` and ``kornia`` to doctests without explicit imports."""
    injected = (("np", np), ("torch", torch), ("kornia", kornia))
    for alias, module in injected:
        doctest_namespace[alias] = module
# Commit hashes of the test data version. They are interpolated into the
# github.com/kornia/data_test download URLs built by the `data` fixture;
# different data files are pinned to different commits.
sha: str = "cb8f42bf28b9f347df6afba5558738f62a11f28a"
sha2: str = "f7d8da661701424babb64850e03c5e8faec7ea62"
sha3: str = "8b98f44abbe92b7a84631ed06613b08fee7dae14"
@pytest.fixture(scope="session")
def data(request):
    """Load the test data file selected by ``request.param`` onto the CPU.

    ``request.param`` must be one of the keys of the URL table below; an
    unknown key raises ``KeyError``. The file is fetched with
    ``torch.hub.load_state_dict_from_url``.
    """
    blob_root = "https://github.com/kornia/data_test/blob"

    def _blob_url(commit: str, filename: str) -> str:
        # Raw-download link of `filename` at the given data_test commit.
        return f"{blob_root}/{commit}/{filename}?raw=true"

    url = {
        "loftr_homo": _blob_url(sha, "loftr_outdoor_and_homography_data.pt"),
        "loftr_fund": _blob_url(sha, "loftr_indoor_and_fundamental_data.pt"),
        "adalam_idxs": _blob_url(sha2, "adalam_test.pt"),
        "lightglue_idxs": _blob_url(sha2, "adalam_test.pt"),
        "disk_outdoor": _blob_url(sha3, "knchurch_disk.pt"),
        "dexined": "https://cmp.felk.cvut.cz/~mishkdmy/models/DexiNed_BIPED_10.pth",
    }
    target = url[request.param]
    return torch.hub.load_state_dict_from_url(target, map_location=torch.device("cpu"))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mirrors/Kornia.git
git@gitee.com:mirrors/Kornia.git
mirrors
Kornia
Kornia
main

搜索帮助

344bd9b3 5694891 D2dac590 5694891