1 Star 0 Fork 0

BugCat/pytorch-i3d

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
charades_dataset_full.py 4.80 KB
一键复制 编辑 原始数据 按行查看 历史
BugCat 提交于 4年前 . 修改
import torch
import torch.utils.data as data_utl
from torch.utils.data.dataloader import default_collate
import numpy as np
import json
import csv
import h5py
import os
import os.path
import cv2
# Action names in class-index order: the position of a name in this tuple is
# the integer index that class_dict assigns to it.
_ACTION_NAMES = (
    'BaseballPitch',
    'BasketballDunk',
    'Billiards',
    'CleanAndJerk',
    'CliffDiving',
    'CricketBowling',
    'CricketShot',
    'Diving',
    'FrisbeeCatch',
    'GolfSwing',
    'HammerThrow',
    'HighJump',
    'JavelinThrow',
    'LongJump',
    'PoleVault',
    'Shotput',
    'SoccerPenalty',
    'TennisSwing',
    'ThrowDiscus',
    'VolleyballSpiking',
)

# Maps an annotation label string to its class index (0..19); make_dataset
# uses that index as the row of the per-frame label matrix.
class_dict = {name: idx for idx, name in enumerate(_ACTION_NAMES)}
def video_to_tensor(pic):
    """Rearrange a video array into channel-first layout and wrap it as a tensor.

    Args:
        pic (numpy.ndarray): Video laid out as (T x H x W x C).

    Returns:
        Tensor: The video laid out as (C x T x H x W). It is created with
        ``torch.from_numpy``, so its dtype follows the dtype of ``pic``.
    """
    # Move the channel axis (index 3) to the front; the T, H, W order is kept.
    channel_first = pic.transpose(3, 0, 1, 2)
    return torch.from_numpy(channel_first)
def load_rgb_frames(image_dir, vid, start, num):
    """Load consecutive RGB frames of one video into a float32 array.

    Frame ``i`` is read from ``<image_dir>/<vid>/<i>.jpg`` for every ``i`` in
    ``[start, start + num)``. Each frame is converted from OpenCV's BGR
    channel order to RGB, upscaled (aspect ratio preserved) when its shorter
    side is below 226 pixels, and rescaled from [0, 255] to [-1, 1].

    Args:
        image_dir (str): Root directory that contains one folder per video.
        vid (str): Video identifier, used as the folder name.
        start (int): Index of the first frame to load.
        num (int): Number of frames to load.

    Returns:
        numpy.ndarray: float32 array of the stacked frames (T x H x W x 3 when
        all frames share one size). Empty when ``num`` is 0.

    Raises:
        FileNotFoundError: If a frame image is missing or cannot be decoded.
    """
    frames = []
    for i in range(start, start + num):
        path = os.path.join(image_dir, vid, str(i) + '.jpg')
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(
                'Frame image is missing or unreadable: ' + path)
        # Reverse the channel axis: BGR (OpenCV order) -> RGB.
        img = img[:, :, [2, 1, 0]]
        # ndarray images are indexed (rows, cols, channels) = (H, W, C).
        h, w, _ = img.shape
        short_side = min(h, w)
        if short_side < 226:
            # Scale factor that brings the shorter side up to 226 pixels.
            d = 226. - short_side
            sc = 1 + d / short_side
            img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)
        # Map pixel values from [0, 255] to [-1, 1].
        img = (img / 255.) * 2 - 1
        frames.append(img)
    return np.asarray(frames, dtype=np.float32)
def load_flow_frames(image_dir, vid, start, num):
    """Load consecutive optical-flow frames of one video into a float32 array.

    For every ``i`` in ``[start, start + num)`` the two grayscale images
    ``<image_dir>/<vid>/<i>_x.jpg`` and ``<image_dir>/<vid>/<i>_y.jpg`` are
    read, upscaled (aspect ratio preserved) when the shorter side of the x
    image is below 224 pixels, rescaled from [0, 255] to [-1, 1], and stacked
    as the two channels of one frame.

    Args:
        image_dir (str): Root directory that contains one folder per video.
        vid (str): Video identifier, used as the folder name.
        start (int): Index of the first frame to load.
        num (int): Number of frames to load.

    Returns:
        numpy.ndarray: float32 array of the stacked frames (T x H x W x 2 when
        all frames share one size). Empty when ``num`` is 0.

    Raises:
        FileNotFoundError: If a flow image is missing or cannot be decoded.
    """
    frames = []
    for i in range(start, start + num):
        path_x = os.path.join(image_dir, vid, str(i) + '_' + 'x.jpg')
        path_y = os.path.join(image_dir, vid, str(i) + '_' + 'y.jpg')
        imgx = cv2.imread(path_x, cv2.IMREAD_GRAYSCALE)
        imgy = cv2.imread(path_y, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising.
        if imgx is None:
            raise FileNotFoundError(
                'Flow image is missing or unreadable: ' + path_x)
        if imgy is None:
            raise FileNotFoundError(
                'Flow image is missing or unreadable: ' + path_y)
        # Grayscale ndarray images are indexed (rows, cols) = (H, W).
        h, w = imgx.shape
        short_side = min(h, w)
        if short_side < 224:
            # Scale factor that brings the shorter side up to 224 pixels;
            # both components are resized with the same factor.
            d = 224. - short_side
            sc = 1 + d / short_side
            imgx = cv2.resize(imgx, dsize=(0, 0), fx=sc, fy=sc)
            imgy = cv2.resize(imgy, dsize=(0, 0), fx=sc, fy=sc)
        # Map pixel values from [0, 255] to [-1, 1].
        imgx = (imgx / 255.) * 2 - 1
        imgy = (imgy / 255.) * 2 - 1
        # (2, H, W) -> (H, W, 2): x and y flow become the channel axis.
        img = np.asarray([imgx, imgy]).transpose([1, 2, 0])
        frames.append(img)
    return np.asarray(frames, dtype=np.float32)
def make_dataset(split_file, split, root, mode, num_classes=157):
    """Build the list of videos, with per-frame labels, for one split.

    The split file is a JSON object keyed by video id. Each entry is read for
    its ``'subset'``, ``'duration'`` and ``'annotations'`` fields; every
    annotation provides a ``'segment'`` (start, end) and a ``'label'`` that is
    looked up in ``class_dict``. Videos from another subset, or without a
    frame directory under ``root``, are skipped.

    Args:
        split_file (str): Path to the JSON split/annotation file.
        split (str): Subset name to keep (compared with ``'subset'``).
        root (str): Directory that contains one frame folder per video.
        mode (str): ``'flow'`` halves the file count (two images per frame);
            any other value counts one file per frame.
        num_classes (int): Number of rows in each label matrix.

    Returns:
        list: One ``(vid, label, duration, num_frames)`` tuple per video, where
        ``label`` is a float32 array of shape (num_classes, num_frames) holding
        1 where the frame time lies strictly inside an annotated segment.

    Raises:
        KeyError: If a frame falls inside a segment whose label is not in
            ``class_dict``.
        ZeroDivisionError: If a kept video has a duration of 0.
    """
    with open(split_file, 'r') as f:
        data = json.load(f)

    dataset = []
    for vid, info in data.items():
        if info['subset'] != split:
            continue
        vid_dir = os.path.join(root, vid)
        # isdir (not exists): a plain file with the video's name has no frames
        # to list and would make os.listdir raise.
        if not os.path.isdir(vid_dir):
            continue

        num_frames = len(os.listdir(vid_dir))
        if mode == 'flow':
            # Flow stores an x and a y image per frame.
            num_frames = num_frames // 2

        label = np.zeros((num_classes, num_frames), np.float32)
        fps = num_frames / info['duration']
        # Timestamp (seconds) of every frame, computed once per video.
        frame_times = np.arange(num_frames) / fps

        for ann in info['annotations']:
            segment = ann['segment']
            inside = ((frame_times > float(segment[0]))
                      & (frame_times < float(segment[1])))
            # Only resolve the label when at least one frame is hit, so an
            # annotation covering no frame never triggers a class lookup.
            if inside.any():
                label[class_dict[ann['label']], inside] = 1  # binary classification

        dataset.append((vid, label, info['duration'], num_frames))
    return dataset
class Charades(data_utl.Dataset):
    """Dataset that yields whole videos with per-frame multi-class labels.

    The list of videos is built once in ``__init__`` by ``make_dataset``;
    frames are read from disk on every ``__getitem__`` call.
    """

    def __init__(self, split_file, split, root, mode, transforms=None, save_dir='', num=0):
        """Index the videos of one split.

        Args:
            split_file (str): Path to the JSON split/annotation file.
            split (str): Subset name to load.
            root (str): Directory that contains one frame folder per video.
            mode (str): ``'rgb'`` loads RGB frames; any other value loads
                optical-flow frames.
            transforms (callable, optional): Applied to the loaded frame array
                before it is converted to a tensor. ``None`` applies nothing.
            save_dir (str): Stored on the instance; not read by this class.
            num (int): Accepted for interface compatibility; unused.
        """
        # One label row per known action class.
        self.data = make_dataset(split_file, split, root, mode,
                                 num_classes=len(class_dict))
        self.split_file = split_file
        self.transforms = transforms
        self.mode = mode
        self.root = root
        self.save_dir = save_dir

    def __getitem__(self, index):
        """Load every frame of one video together with its labels.

        Args:
            index (int): Index into the list of videos.

        Returns:
            tuple: ``(video, label, vid)`` where ``video`` is the tensor
            produced by ``video_to_tensor``, ``label`` is the
            (num_classes x num_frames) label tensor and ``vid`` is the video
            identifier.
        """
        vid, label, dur, nf = self.data[index]

        if self.mode == 'rgb':
            imgs = load_rgb_frames(self.root, vid, 0, nf)
        else:
            imgs = load_flow_frames(self.root, vid, 0, nf)

        # transforms defaults to None, so it must not be called blindly.
        if self.transforms is not None:
            imgs = self.transforms(imgs)

        return video_to_tensor(imgs), torch.from_numpy(label), vid

    def __len__(self):
        """Return the number of videos in the split."""
        return len(self.data)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/bugcat9/pytorch-i3d.git
git@gitee.com:bugcat9/pytorch-i3d.git
bugcat9
pytorch-i3d
pytorch-i3d
master

搜索帮助