import sys
from functools import partial
from itertools import product
from typing import Dict
import numpy as np
import pytest
import torch
import kornia
from kornia.utils._compat import torch_version
# torch.compile backends reported by dynamo; stays an empty list when importing
# ``torch._dynamo`` (or listing its backends) raises ImportError.
_backends_non_experimental = []
try:
    import torch._dynamo

    _backends_non_experimental = torch._dynamo.list_backends()
except ImportError:
    # Deliberate best-effort: without dynamo there are simply no compile backends to test.
    pass
def get_test_devices() -> Dict[str, torch.device]:
    """Collect the devices the test-suite can run on.

    CPU is always included. CUDA (``cuda:0``), XLA/TPU and Apple MPS devices are added only
    when the current hardware/software stack reports them as available.

    Return:
        dict(str, torch.device): mapping from device name (``"cpu"``, ``"cuda"``, ``"tpu"``,
        ``"mps"``) to the corresponding device.
    """
    found: Dict[str, torch.device] = {"cpu": torch.device("cpu")}

    if torch.cuda.is_available():
        found["cuda"] = torch.device("cuda:0")

    if kornia.xla_is_available():
        # Imported lazily: torch_xla is only installed on XLA-capable hosts.
        import torch_xla.core.xla_model as xm

        found["tpu"] = xm.xla_device()

    # Older torch builds have no ``torch.backends.mps`` at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        found["mps"] = torch.device("mps")

    return found
def get_test_dtypes() -> Dict[str, torch.dtype]:
    """Map dtype names to the floating-point ``torch.dtype`` values exercised by the tests.

    Return:
        dict(str, torch.dtype): ``"bfloat16"``, ``"float16"``, ``"float32"`` and
        ``"float64"`` mapped to the ``torch`` dtypes of the same name, in that order.
    """
    dtype_names = ("bfloat16", "float16", "float32", "float64")
    return {name: getattr(torch, name) for name in dtype_names}
# Devices, dtypes and optimizer backends the test-suite can be parametrized over.
TEST_DEVICES: Dict[str, torch.device] = get_test_devices()
TEST_DTYPES: Dict[str, torch.dtype] = get_test_dtypes()
# "" and None both mean "no optimizer" in the ``torch_optimizer`` fixture, "jit" selects
# ``torch.jit.script``, and the remaining entries are torch dynamo backends.
TEST_OPTIMIZER_BACKEND = {"", None, "jit", *_backends_non_experimental}

# (device_name, dtype_name) pairs excluded from testing, e.g. {("cpu", "float16")}.
# Only membership of pairs is ever checked, so this is a set (``{}`` would be a dict).
DEVICE_DTYPE_BLACKLIST = set()
@pytest.fixture()
def device(device_name) -> torch.device:
    """Resolve the parametrized ``device_name`` into its ``torch.device`` from ``TEST_DEVICES``."""
    selected_device = TEST_DEVICES[device_name]
    return selected_device
@pytest.fixture()
def dtype(dtype_name) -> torch.dtype:
    """Resolve the parametrized ``dtype_name`` into its ``torch.dtype`` from ``TEST_DTYPES``."""
    selected_dtype = TEST_DTYPES[dtype_name]
    return selected_dtype
@pytest.fixture()
def torch_optimizer(optimizer_backend):
    """Return a callable that wraps a function/module with the requested optimizer.

    ``optimizer_backend`` is parametrized from the ``--optimizer`` option:

    - falsy (``""`` or ``None``): identity, i.e. no optimization;
    - ``"jit"``: ``torch.jit.script``;
    - anything else: ``torch.compile`` with that backend.

    The requesting test is skipped when ``torch.compile`` cannot be used here: it is missing,
    the platform is not Linux, or the Python/torch combination is unsupported.
    """
    if not optimizer_backend:
        return lambda x: x

    if optimizer_backend == "jit":
        return torch.jit.script

    if not (hasattr(torch, "compile") and sys.platform == "linux"):
        pytest.skip(
            f"skipped because `torch.compile` is not available with torch {torch.__version__} "
            f"on platform {sys.platform}!"
        )

    # torch.compile only supports Python 3.11 starting from torch 2.1.0.
    if sys.version_info[:2] == (3, 11) and torch_version() in {"2.0.0", "2.0.1"}:
        pytest.skip(f"skipped because `torch.compile` in torch {torch.__version__} does not support Python 3.11!")

    # Reset dynamo's compilation state so earlier tests do not influence this one.
    torch._dynamo.reset()
    # TODO: explore the others parameters of torch compile
    return partial(torch.compile, backend=optimizer_backend)
def _selected_names(metafunc, option: str, available) -> list:
    """Return the names requested via ``option``; ``"all"`` expands to ``available``.

    Any other value is treated as a comma-separated list of names.
    """
    raw_value = metafunc.config.getoption(option)
    if raw_value == "all":
        return list(available)
    return raw_value.split(",")


def pytest_generate_tests(metafunc):
    """Parametrize ``device_name``, ``dtype_name`` and ``optimizer_backend`` from CLI options.

    Every combination of the requested fixtures (in the order device, dtype, optimizer) is
    generated, whichever subset of them a test uses. When a test uses both ``device_name``
    and ``dtype_name``, combinations whose (device, dtype) pair is in
    ``DEVICE_DTYPE_BLACKLIST`` are dropped.
    """
    argnames = []
    argvalues = []

    if "device_name" in metafunc.fixturenames:
        argnames.append("device_name")
        argvalues.append(_selected_names(metafunc, "--device", TEST_DEVICES))

    if "dtype_name" in metafunc.fixturenames:
        argnames.append("dtype_name")
        argvalues.append(_selected_names(metafunc, "--dtype", TEST_DTYPES))

    if "optimizer_backend" in metafunc.fixturenames:
        argnames.append("optimizer_backend")
        # Sort so every process collects tests in the same order: iterating a set of strings
        # depends on the per-process hash seed, which breaks pytest-xdist. ``key=str`` because
        # the set mixes ``None`` with strings.
        available_backends = sorted(TEST_OPTIMIZER_BACKEND, key=str)
        argvalues.append(_selected_names(metafunc, "--optimizer", available_backends))

    if not argnames:
        return

    if len(argnames) == 1:
        # A single argname takes plain values, not 1-tuples.
        metafunc.parametrize(argnames[0], argvalues[0])
        return

    params = list(product(*argvalues))
    if "device_name" in argnames and "dtype_name" in argnames:
        # When both are present they are always the first two entries of each combination,
        # so compare only that pair against the blacklist (also for 3-way combinations).
        params = [combo for combo in params if combo[:2] not in DEVICE_DTYPE_BLACKLIST]
    metafunc.parametrize(",".join(argnames), params)
def pytest_collection_modifyitems(config, items):
    """Attach a skip marker to every ``slow``-marked test unless ``--runslow`` is given."""
    if not config.getoption("--runslow"):
        skip_slow = pytest.mark.skip(reason="need --runslow option to run")
        for slow_item in (item for item in items if "slow" in item.keywords):
            slow_item.add_marker(skip_slow)
def pytest_addoption(parser):
    """Register kornia's custom command-line options with pytest.

    ``--device``, ``--dtype`` and ``--optimizer`` accept a comma-separated list of names or
    ``all``; they drive the parametrization done in ``pytest_generate_tests``.
    """
    parser.addoption(
        "--device",
        action="store",
        default="cpu",
        help="comma-separated device names to test (e.g. 'cpu,cuda'), or 'all' for every available device",
    )
    parser.addoption(
        "--dtype",
        action="store",
        default="float32",
        help="comma-separated dtype names to test (e.g. 'float32,float64'), or 'all'",
    )
    parser.addoption(
        "--optimizer",
        action="store",
        default="inductor",
        help="comma-separated optimizer backends ('jit' or a torch.compile backend), or 'all'",
    )
    parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
def _setup_torch_compile():
    """Prepare ``torch.compile`` once for the session; does nothing off Linux or without it.

    Sets the float32 matmul precision to ``"high"`` and hands a dummy function and a dummy
    ``nn.Module`` to ``torch.compile``. The dummy calls presumably make an unsupported setup
    raise here, once, rather than inside individual tests.
    """
    if not hasattr(torch, "compile") or sys.platform != "linux":
        return

    print("Setting up torch compile...")
    torch.set_float32_matmul_precision("high")

    def _add_then_sum(x, y):
        return (x + y).sum()

    class _SquaredSum(torch.nn.Module):
        def forward(self, x):
            return (x**2).sum()

    torch.compile(_add_then_sum)
    torch.compile(_SquaredSum())
def pytest_sessionstart(session):
    """Run ``_setup_torch_compile`` at session start.

    A ``RuntimeError`` whose message contains "not yet supported for torch.compile" is
    ignored; presumably torch raises it on interpreters that torch.compile does not support.
    Any other ``RuntimeError`` propagates.
    """
    try:
        _setup_torch_compile()
    except RuntimeError as err:
        if "not yet supported for torch.compile" in str(err):
            return
        raise
    # TODO: cache all torch.load weights/states here to not impact on test suite
def _parse_env_key_values(text) -> Dict[str, str]:
    """Parse newline-separated ``"key: value"`` lines into a dict.

    Only the first ``":"`` separates key from value, so values that themselves contain
    ``":"`` are kept whole. A line without ``":"`` maps to itself. Falsy input gives ``{}``.
    """
    parsed: Dict[str, str] = {}
    if not text:
        return parsed
    for line in text.split("\n"):
        parts = line.split(":", 1)
        parsed[parts[0].strip()] = parts[-1].strip()
    return parsed


def _get_env_info() -> Dict[str, Dict[str, str]]:
    """Collect host environment details through ``torch.utils.collect_env``.

    Returns:
        ``{}`` when ``torch.utils.collect_env`` is not available. Otherwise a dict with:
        ``"cpu"`` and ``"gpu"``, each a mapping parsed from the helper's ``"key: value"``
        lines; and ``"nvidia"`` and ``"gcc"``, the values returned by the driver/gcc version
        helpers as-is (despite the annotation, not dicts). A helper missing from
        ``collect_env`` yields ``{}`` (cpu/gpu) or ``None`` (nvidia/gcc).
    """
    if not hasattr(torch.utils, "collect_env"):
        return {}

    collect_env = torch.utils.collect_env
    run_lmb = collect_env.run

    def _query(helper_name):
        # Presumably not every torch release ships every helper; report a missing one
        # as None instead of letting AttributeError abort the pytest report header.
        helper = getattr(collect_env, helper_name, None)
        return None if helper is None else helper(run_lmb)

    return {
        "cpu": _parse_env_key_values(_query("get_cpu_info")),
        "gpu": _parse_env_key_values(_query("get_gpu_info")),
        "nvidia": _query("get_nvidia_driver_version"),
        "gcc": _query("get_gcc_version"),
    }
def pytest_report_header(config):
    """Return the extra header text pytest prints at the start of a session.

    Contains selected CPU details and the GPU/gcc info from ``_get_env_info``, the versions of
    kornia/torch and of the extra (``accelerate``) and dev (``kornia_rs``, ``onnx``) packages,
    and the optimizer backends available for testing. A package that is not installed is
    reported as "`<name>` not found" instead of aborting the session.
    """
    import importlib

    def _package_info(module_name: str) -> str:
        # "<name>-<version>" when importable, a placeholder otherwise.
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            return f"`{module_name}` not found"
        return f"{module_name}-{module.__version__}"

    accelerate_info = _package_info("accelerate")
    kornia_rs_info = _package_info("kornia_rs")
    onnx_info = _package_info("onnx")

    env_info = _get_env_info()
    if "cpu" in env_info:
        desired_cpu_info = ["Model name", "Architecture", "CPU(s)", "Thread(s) per core", "CPU max MHz", "CPU min MHz"]
        cpu_info = "cpu info:\n" + "\n".join(
            f'\t- {key}: {env_info["cpu"][key]}' for key in desired_cpu_info if key in env_info["cpu"]
        )
    else:
        cpu_info = ""
    gpu_info = f"gpu info: {env_info['gpu']}" if "gpu" in env_info else ""
    gcc_info = f"gcc info: {env_info['gcc']}" if "gcc" in env_info else ""
    nvidia_info = env_info.get("nvidia")

    return f"""
{cpu_info}
{gpu_info}
main deps:
    - kornia-{kornia.__version__}
    - torch-{torch.__version__}
        - commit: {torch.version.git_version}
        - cuda: {torch.version.cuda}
        - nvidia-driver: {nvidia_info}
x deps:
    - {accelerate_info}
dev deps:
    - {kornia_rs_info}
    - {onnx_info}
{gcc_info}
available optimizers: {TEST_OPTIMIZER_BACKEND}
"""
@pytest.fixture(autouse=True)
def add_doctest_deps(doctest_namespace):
    """Expose ``np``, ``torch`` and ``kornia`` to doctests without explicit imports."""
    doctest_namespace.update(np=np, torch=torch, kornia=kornia)
# Commit hashes in the kornia/data_test repository that pin the version of the test data.
sha: str = "cb8f42bf28b9f347df6afba5558738f62a11f28a"
sha2: str = "f7d8da661701424babb64850e03c5e8faec7ea62"
sha3: str = "8b98f44abbe92b7a84631ed06613b08fee7dae14"


@pytest.fixture(scope="session")
def data(request):
    """Download the test data named by ``request.param`` and load it onto the CPU.

    ``request.param`` must be one of the keys of the mapping below; the file is fetched
    with ``torch.hub.load_state_dict_from_url`` using ``map_location=cpu``.
    """
    sources = {
        "loftr_homo": f"https://github.com/kornia/data_test/blob/{sha}/loftr_outdoor_and_homography_data.pt?raw=true",
        "loftr_fund": f"https://github.com/kornia/data_test/blob/{sha}/loftr_indoor_and_fundamental_data.pt?raw=true",
        "adalam_idxs": f"https://github.com/kornia/data_test/blob/{sha2}/adalam_test.pt?raw=true",
        # NOTE(review): same file as "adalam_idxs" -- confirm lightglue tests really reuse it.
        "lightglue_idxs": f"https://github.com/kornia/data_test/blob/{sha2}/adalam_test.pt?raw=true",
        "disk_outdoor": f"https://github.com/kornia/data_test/blob/{sha3}/knchurch_disk.pt?raw=true",
        "dexined": "https://cmp.felk.cvut.cz/~mishkdmy/models/DexiNed_BIPED_10.pth",
    }
    selected_url = sources[request.param]
    cpu_device = torch.device("cpu")
    return torch.hub.load_state_dict_from_url(selected_url, map_location=cpu_device)