代码拉取完成,页面将自动刷新
import torch
import torch.nn as nn
import threading
from torch._utils import ExceptionWrapper
import logging
def get_a_var(obj):
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, list) or isinstance(obj, tuple):
for result in map(get_a_var, obj):
if isinstance(result, torch.Tensor):
return result
if isinstance(obj, dict):
for result in map(get_a_var, obj.items()):
if isinstance(result, torch.Tensor):
return result
return None
def parallel_apply(fct, model, inputs, device_ids):
modules = nn.parallel.replicate(model, device_ids)
assert len(modules) == len(inputs)
lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input):
torch.set_grad_enabled(grad_enabled)
device = get_a_var(input).get_device()
try:
with torch.cuda.device(device):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
output = fct(module, *input)
with lock:
results[i] = output
except Exception:
with lock:
results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))
if len(modules) > 1:
threads = [threading.Thread(target=_worker, args=(i, module, input))
for i, (module, input) in enumerate(zip(modules, inputs))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, ExceptionWrapper):
output.reraise()
outputs.append(output)
return outputs
def get_logger(filename=None):
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
if filename is not None:
handler = logging.FileHandler(filename)
handler.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
logging.getLogger().addHandler(handler)
return logger
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。