代码拉取完成,页面将自动刷新
import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset,DataLoader
import torchvision.datasets as dsets
import matplotlib.pyplot as plt
batch_size = 800
train_data = dsets.CIFAR10(root='./cifar-10', train=True, download=False, transform=transforms.Compose([transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor(), transforms.RandomErasing(p=0.5, scale=(0.02, 0.4), ratio=(0.33, 3.0))]))
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
test_data = dsets.CIFAR10(root='./cifar-10', train=False, download=False, transform=transforms.Compose([transforms.ToTensor(),]))
test_loader = DataLoader(test_data,batch_size=batch_size,shuffle=False)
def image_show(data_loader,n):
tmp = iter(data_loader)
images,labels = tmp.next()
images = images.numpy()
for i in range(n):
image = np.transpose(images[i],[1,2,0])
plt.imshow(image)
plt.show()
image_show(train_loader,10)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。