1 Star 0 Fork 0

李锦峰/CLIP4Clip

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
data_dataloaders.py 8.61 KB
一键复制 编辑 原始数据 按行查看 历史
import torch
from torch.utils.data import DataLoader
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_DataLoader
from dataloaders.dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader
from dataloaders.dataloader_msvd_retrieval import MSVD_DataLoader
from dataloaders.dataloader_lsmdc_retrieval import LSMDC_DataLoader
from dataloaders.dataloader_activitynet_retrieval import ActivityNet_DataLoader
from dataloaders.dataloader_didemo_retrieval import DiDeMo_DataLoader
def dataloader_msrvtt_train(args, tokenizer):
    """Build the distributed training loader for MSR-VTT.

    Args:
        args: Run configuration. The attributes read here are ``train_csv``,
            ``data_path``, ``features_path``, ``max_words``,
            ``feature_framerate``, ``max_frames``, ``expand_msrvtt_sentences``,
            ``train_frame_order``, ``slice_framepos``, ``batch_size``,
            ``n_gpu`` and ``num_thread_reader``.
        tokenizer: Tokenizer object forwarded unchanged to the dataset.

    Returns:
        A ``(loader, dataset_length, sampler)`` tuple. The sampler is returned
        alongside the loader so the caller keeps a handle on it.
    """
    train_set = MSRVTT_TrainDataLoader(
        csv_path=args.train_csv,
        json_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        unfold_sentences=args.expand_msrvtt_sentences,
        frame_order=args.train_frame_order,
        slice_framepos=args.slice_framepos,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_set)

    # args.batch_size is split evenly (integer division) across args.n_gpu.
    per_gpu_batch = args.batch_size // args.n_gpu

    loader = DataLoader(
        train_set,
        batch_size=per_gpu_batch,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        # A sampler is always created above, so this evaluates to False and
        # sample ordering is left entirely to the sampler.
        shuffle=(sampler is None),
        sampler=sampler,
        drop_last=True,
    )
    return loader, len(train_set), sampler
def dataloader_msrvtt_test(args, tokenizer, subset="test"):
    """Build the evaluation loader for MSR-VTT.

    Args:
        args: Run configuration. The attributes read here are ``val_csv``,
            ``features_path``, ``max_words``, ``feature_framerate``,
            ``max_frames``, ``eval_frame_order``, ``slice_framepos``,
            ``batch_size_val`` and ``num_thread_reader``.
        tokenizer: Tokenizer object forwarded unchanged to the dataset.
        subset: Accepted for signature parity with the other ``*_test``
            factories; this function does not read it.

    Returns:
        A ``(loader, dataset_length)`` tuple.
    """
    eval_set = MSRVTT_DataLoader(
        csv_path=args.val_csv,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.eval_frame_order,
        slice_framepos=args.slice_framepos,
    )
    # Evaluation keeps dataset order and keeps the final partial batch.
    eval_loader = DataLoader(
        eval_set,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )
    return eval_loader, len(eval_set)
def dataloader_msvd_train(args, tokenizer):
    """Build the distributed training loader for MSVD.

    Args:
        args: Run configuration. The attributes read here are ``data_path``,
            ``features_path``, ``max_words``, ``feature_framerate``,
            ``max_frames``, ``train_frame_order``, ``slice_framepos``,
            ``batch_size``, ``n_gpu`` and ``num_thread_reader``.
        tokenizer: Tokenizer object forwarded unchanged to the dataset.

    Returns:
        A ``(loader, dataset_length, sampler)`` tuple.
    """
    train_set = MSVD_DataLoader(
        subset="train",
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.train_frame_order,
        slice_framepos=args.slice_framepos,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_set)

    # args.batch_size is split evenly (integer division) across args.n_gpu.
    per_gpu_batch = args.batch_size // args.n_gpu

    loader = DataLoader(
        train_set,
        batch_size=per_gpu_batch,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        # Always False here: the sampler above is never None.
        shuffle=(sampler is None),
        sampler=sampler,
        drop_last=True,
    )
    return loader, len(train_set), sampler
def dataloader_msvd_test(args, tokenizer, subset="test"):
    """Build the MSVD evaluation loader for the requested split.

    Args:
        args: Run configuration. The attributes read here are ``data_path``,
            ``features_path``, ``max_words``, ``feature_framerate``,
            ``max_frames``, ``eval_frame_order``, ``slice_framepos``,
            ``batch_size_val`` and ``num_thread_reader``.
        tokenizer: Tokenizer object forwarded unchanged to the dataset.
        subset: Split name handed to the dataset; defaults to ``"test"``.

    Returns:
        A ``(loader, dataset_length)`` tuple.
    """
    eval_set = MSVD_DataLoader(
        subset=subset,
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.eval_frame_order,
        slice_framepos=args.slice_framepos,
    )
    # Evaluation keeps dataset order and keeps the final partial batch.
    eval_loader = DataLoader(
        eval_set,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )
    return eval_loader, len(eval_set)
def dataloader_lsmdc_train(args, tokenizer):
    """Build the distributed training loader for LSMDC.

    Args:
        args: Run configuration. The attributes read here are ``data_path``,
            ``features_path``, ``max_words``, ``feature_framerate``,
            ``max_frames``, ``train_frame_order``, ``slice_framepos``,
            ``batch_size``, ``n_gpu`` and ``num_thread_reader``.
        tokenizer: Tokenizer object forwarded unchanged to the dataset.

    Returns:
        A ``(loader, dataset_length, sampler)`` tuple.
    """
    train_set = LSMDC_DataLoader(
        subset="train",
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.train_frame_order,
        slice_framepos=args.slice_framepos,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_set)

    # args.batch_size is split evenly (integer division) across args.n_gpu.
    per_gpu_batch = args.batch_size // args.n_gpu

    loader = DataLoader(
        train_set,
        batch_size=per_gpu_batch,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        # Always False here: the sampler above is never None.
        shuffle=(sampler is None),
        sampler=sampler,
        drop_last=True,
    )
    return loader, len(train_set), sampler
def dataloader_lsmdc_test(args, tokenizer, subset="test"):
    """Build the LSMDC evaluation loader for the requested split.

    Args:
        args: Run configuration. The attributes read here are ``data_path``,
            ``features_path``, ``max_words``, ``feature_framerate``,
            ``max_frames``, ``eval_frame_order``, ``slice_framepos``,
            ``batch_size_val`` and ``num_thread_reader``.
        tokenizer: Tokenizer object forwarded unchanged to the dataset.
        subset: Split name handed to the dataset; defaults to ``"test"``.

    Returns:
        A ``(loader, dataset_length)`` tuple.
    """
    eval_set = LSMDC_DataLoader(
        subset=subset,
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.eval_frame_order,
        slice_framepos=args.slice_framepos,
    )
    # Evaluation keeps dataset order and keeps the final partial batch.
    eval_loader = DataLoader(
        eval_set,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )
    return eval_loader, len(eval_set)
def dataloader_activity_train(args, tokenizer):
    """Build the distributed training loader for ActivityNet.

    Args:
        args: Run configuration. The attributes read here are ``data_path``,
            ``features_path``, ``max_words``, ``feature_framerate``,
            ``max_frames``, ``train_frame_order``, ``slice_framepos``,
            ``batch_size``, ``n_gpu`` and ``num_thread_reader``.
        tokenizer: Tokenizer object forwarded unchanged to the dataset.

    Returns:
        A ``(loader, dataset_length, sampler)`` tuple.
    """
    train_set = ActivityNet_DataLoader(
        subset="train",
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.train_frame_order,
        slice_framepos=args.slice_framepos,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_set)

    # args.batch_size is split evenly (integer division) across args.n_gpu.
    per_gpu_batch = args.batch_size // args.n_gpu

    loader = DataLoader(
        train_set,
        batch_size=per_gpu_batch,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        # Always False here: the sampler above is never None.
        shuffle=(sampler is None),
        sampler=sampler,
        drop_last=True,
    )
    return loader, len(train_set), sampler
def dataloader_activity_test(args, tokenizer, subset="test"):
    """Build the ActivityNet evaluation loader for the requested split.

    Args:
        args: Run configuration. The attributes read here are ``data_path``,
            ``features_path``, ``max_words``, ``feature_framerate``,
            ``max_frames``, ``eval_frame_order``, ``slice_framepos``,
            ``batch_size_val`` and ``num_thread_reader``.
        tokenizer: Tokenizer object forwarded unchanged to the dataset.
        subset: Split name handed to the dataset; defaults to ``"test"``.

    Returns:
        A ``(loader, dataset_length)`` tuple.
    """
    eval_set = ActivityNet_DataLoader(
        subset=subset,
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.eval_frame_order,
        slice_framepos=args.slice_framepos,
    )
    # Evaluation keeps dataset order and keeps the final partial batch.
    eval_loader = DataLoader(
        eval_set,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )
    return eval_loader, len(eval_set)
def dataloader_didemo_train(args, tokenizer):
    """Build the distributed training loader for DiDeMo.

    Args:
        args: Run configuration. The attributes read here are ``data_path``,
            ``features_path``, ``max_words``, ``feature_framerate``,
            ``max_frames``, ``train_frame_order``, ``slice_framepos``,
            ``batch_size``, ``n_gpu`` and ``num_thread_reader``.
        tokenizer: Tokenizer object forwarded unchanged to the dataset.

    Returns:
        A ``(loader, dataset_length, sampler)`` tuple.
    """
    train_set = DiDeMo_DataLoader(
        subset="train",
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.train_frame_order,
        slice_framepos=args.slice_framepos,
    )
    sampler = torch.utils.data.distributed.DistributedSampler(train_set)

    # args.batch_size is split evenly (integer division) across args.n_gpu.
    per_gpu_batch = args.batch_size // args.n_gpu

    loader = DataLoader(
        train_set,
        batch_size=per_gpu_batch,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        # Always False here: the sampler above is never None.
        shuffle=(sampler is None),
        sampler=sampler,
        drop_last=True,
    )
    return loader, len(train_set), sampler
def dataloader_didemo_test(args, tokenizer, subset="test"):
    """Build the DiDeMo evaluation loader for the requested split.

    Args:
        args: Run configuration. The attributes read here are ``data_path``,
            ``features_path``, ``max_words``, ``feature_framerate``,
            ``max_frames``, ``eval_frame_order``, ``slice_framepos``,
            ``batch_size_val`` and ``num_thread_reader``.
        tokenizer: Tokenizer object forwarded unchanged to the dataset.
        subset: Split name handed to the dataset; defaults to ``"test"``.

    Returns:
        A ``(loader, dataset_length)`` tuple.
    """
    eval_set = DiDeMo_DataLoader(
        subset=subset,
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.eval_frame_order,
        slice_framepos=args.slice_framepos,
    )
    # Evaluation keeps dataset order and keeps the final partial batch.
    eval_loader = DataLoader(
        eval_set,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )
    return eval_loader, len(eval_set)
# Registry of loader factories: dataset name -> split name -> factory function.
# A ``None`` value means no factory is registered for that split.
DATALOADER_DICT = {
    "msrvtt": {
        "train": dataloader_msrvtt_train,
        "val": dataloader_msrvtt_test,
        "test": None,
    },
    "msvd": {
        "train": dataloader_msvd_train,
        "val": dataloader_msvd_test,
        "test": dataloader_msvd_test,
    },
    "lsmdc": {
        "train": dataloader_lsmdc_train,
        "val": dataloader_lsmdc_test,
        "test": dataloader_lsmdc_test,
    },
    "activity": {
        "train": dataloader_activity_train,
        "val": dataloader_activity_test,
        "test": None,
    },
    "didemo": {
        "train": dataloader_didemo_train,
        "val": dataloader_didemo_test,
        "test": dataloader_didemo_test,
    },
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/li_jinfeng111/CLIP4Clip.git
git@gitee.com:li_jinfeng111/CLIP4Clip.git
li_jinfeng111
CLIP4Clip
CLIP4Clip
master

搜索帮助

371d5123 14472233 46e8bd33 14472233