1 Star 0 Fork 0

Gao Xing/benchmark

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目，私有仓库也完全免费 :)
免费加入
克隆/下载
test.py 2.34 KB
一键复制 编辑 原始数据 按行查看 历史
nikithamalgifb 提交于 2020-09-21 15:06 . Add DLRM Model (#112)
"""test.py
Set up and run hub models.
Make sure to enable an https proxy if necessary, or the setup steps may hang.
"""
# This file shows how to use the benchmark suite from the user's side.
import argparse
import time
from torchbenchmark import list_models
from unittest import TestCase
import re, sys, unittest
import os.path
import torch
import gc
class TestBenchmark(TestCase):
    """Container for the per-model tests that ``_load_test`` attaches.

    For tests generated for the 'cuda' device (method names ending in
    ``_cuda``), the fixture records ``torch.cuda.memory_allocated()`` before
    the test and asserts it is unchanged afterwards. A failure means the test
    left CUDA memory allocated.
    """

    def _tracks_cuda_memory(self):
        """Return True if this test should be checked for CUDA memory leaks.

        Only CUDA-device tests are checked, and only when CUDA is usable.
        Querying the CUDA allocator on a machine without CUDA may fail on some
        torch versions, which would error out in setUp before the test body
        could skip itself.
        """
        # id() is "module.Class.method"; generated CUDA tests end in "_cuda".
        return self.id().endswith('_cuda') and torch.cuda.is_available()

    def setUp(self):
        # Collect leftovers from earlier tests so they are not counted here.
        gc.collect()
        if self._tracks_cuda_memory():
            self.memory = torch.cuda.memory_allocated()

    def tearDown(self):
        # Free unreachable objects (and any tensors they hold) before measuring.
        gc.collect()
        if self._tracks_cuda_memory():
            try:
                self.assertEqual(self.memory, torch.cuda.memory_allocated())
            finally:
                # Release cached blocks even when the leak check fails, so one
                # leaking test does not inflate memory for the following ones.
                torch.cuda.empty_cache()
def run_model(model_class, model_path, device):
    """Instantiate ``model_class`` on ``device`` and return the instance.

    :param model_class: model class; called as ``model_class(device=device)``.
    :param model_path: unused; kept so existing call sites keep working.
    :param device: device string forwarded to the model constructor
        (this file uses 'cpu' and 'cuda').
    :returns: the constructed model instance (previously it was discarded).
    """
    return model_class(device=device)
def _load_test(model_class, device):
    """Attach example/train/eval test methods for one model/device pair.

    Adds ``test_<name>_example_<device>``, ``test_<name>_train_<device>`` and
    ``test_<name>_eval_<device>`` to ``TestBenchmark``, where ``<name>`` is
    ``model_class.name``. Each test skips itself in two cases:
    - CUDA is requested but unavailable.
    - The model raises ``NotImplementedError`` from the method under test.

    :param model_class: model class exposing a ``name`` attribute and
        constructed as ``model_class(device=device)``.
    :param device: device string, 'cpu' or 'cuda' in this file.
    """
    def model_object(self):
        # Skip (rather than fail) CUDA tests on machines without a GPU.
        if device == 'cuda' and not torch.cuda.is_available():
            self.skipTest("torch.cuda not available")
        return model_class(device=device)

    def example(self):
        m = model_object(self)
        # Guard only get_module(): a NotImplementedError raised while running
        # the module is a real failure, not a missing get_module method.
        try:
            module, example_inputs = m.get_module()
        except NotImplementedError:
            self.skipTest('Method get_module is not implemented, skipping...')
        else:
            module(*example_inputs)

    def train(self):
        m = model_object(self)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        try:
            m.train()
        except NotImplementedError:
            self.skipTest('Method train is not implemented, skipping...')
        else:
            print('Finished training on device: {} in {}s.'.format(device, time.perf_counter() - start))

    # Named eval_ to avoid shadowing the builtin; the registered test name
    # below is unchanged.
    def eval_(self):
        m = model_object(self)
        start = time.perf_counter()
        try:
            m.eval()
        except NotImplementedError:
            self.skipTest('Method eval is not implemented, skipping...')
        else:
            print('Finished eval on device: {} in {}s.'.format(device, time.perf_counter() - start))

    setattr(TestBenchmark, f'test_{model_class.name}_example_{device}', example)
    setattr(TestBenchmark, f'test_{model_class.name}_train_{device}', train)
    setattr(TestBenchmark, f'test_{model_class.name}_eval_{device}', eval_)
def _load_tests():
    """Register example/train/eval tests for every model on both CPU and CUDA."""
    # Pairs come out model-major: each model on 'cpu' first, then 'cuda'.
    pairs = ((model_cls, dev)
             for model_cls in list_models()
             for dev in ('cpu', 'cuda'))
    for model_cls, dev in pairs:
        _load_test(model_cls, dev)
# Generate the TestBenchmark methods at import time, so any runner that
# imports this module (or runs it directly) sees them.
_load_tests()

if __name__ == '__main__':
    # Running the file directly executes the generated tests via unittest.
    unittest.main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/clustertech/benchmark.git
git@gitee.com:clustertech/benchmark.git
clustertech
benchmark
benchmark
0.1

搜索帮助