3 Star 9 Fork 6

skywalk / PaddleRobustVideoMatting

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
inference_utils.py 8.09 KB
一键复制 编辑 原始数据 按行查看 历史
skywalk 提交于 2022-05-14 12:46 . update inference_utils.py.
# PaddlePaddle implementation of a ToPILImage transform (Tensor/ndarray -> PIL image)
import paddle
import PIL
import numbers
import numpy as np
from PIL import Image
from paddle.vision.transforms import ToTensor
from paddle.vision.transforms import BaseTransform
from paddle.vision.transforms import functional as F
class ToPILImage(BaseTransform):
    """Convert a ``paddle.Tensor`` or ``numpy.ndarray`` into a PIL image.

    Tensors are treated as CHW (a 2-D tensor gets a leading channel axis) and
    are transposed to HWC before conversion; ndarrays are treated as HWC
    (a 2-D array gets a trailing channel axis).

    Args:
        mode (str, optional): PIL mode of the output image. When ``None`` the
            mode is inferred per call from the channel count and dtype.
        keys: Forwarded unchanged to ``BaseTransform.__init__``.
    """

    def __init__(self, mode=None, keys=None):
        super(ToPILImage, self).__init__(keys)
        self.mode = mode

    def _apply_image(self, pic):
        """Convert one image.

        Args:
            pic (Tensor|np.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL.Image.Image: Converted image.

        Raises:
            TypeError: If ``pic`` is neither a Tensor nor an ndarray, or if no
                mode was requested and none can be inferred from the dtype.
            ValueError: If ``pic`` is not 2- or 3-dimensional, or the requested
                mode is not valid for the number of channels.
        """
        if not isinstance(pic, (paddle.Tensor, np.ndarray)):
            raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(
                type(pic)))

        # Resolve the mode in a local variable. Writing the inferred mode back
        # to self.mode would make the first image's mode stick to the
        # transform instance and reject later images of a different layout.
        mode = self.mode

        if isinstance(pic, paddle.Tensor):
            if pic.ndimension() not in {2, 3}:
                raise ValueError(
                    'pic should be 2/3 dimensional. Got {} dimensions.'.format(
                        pic.ndimension()))
            if pic.ndimension() == 2:
                # if 2D image, add channel dimension (CHW)
                pic = pic.unsqueeze(0)
            # Float tensors are scaled by 255 and cast to uint8 unless a float
            # ('F') image was requested.
            # NOTE(review): presumably float inputs lie in [0, 1] - confirm.
            if "float" in str(pic.numpy().dtype) and mode != 'F':
                pic = pic * 255
                pic = pic.astype('uint8')
            # CHW -> HWC, the layout Image.fromarray expects.
            npimg = np.transpose(pic.numpy(), (1, 2, 0))
        else:
            if pic.ndim not in {2, 3}:
                raise ValueError(
                    'pic should be 2/3 dimensional. Got {} dimensions.'.format(
                        pic.ndim))
            if pic.ndim == 2:
                # if 2D image, add channel dimension (HWC)
                pic = np.expand_dims(pic, 2)
            npimg = pic

        channels = npimg.shape[2]
        if channels == 1:
            expected_mode = None
            npimg = npimg[:, :, 0]
            if npimg.dtype == np.uint8:
                expected_mode = 'L'
            elif npimg.dtype == np.int16:
                expected_mode = 'I;16'
            elif npimg.dtype == np.int32:
                expected_mode = 'I'
            elif npimg.dtype == np.float32:
                expected_mode = 'F'
            if mode is not None and mode != expected_mode:
                raise ValueError(
                    "Incorrect mode ({}) supplied for input type {}. Should be {}"
                    .format(mode, npimg.dtype, expected_mode))
            mode = expected_mode
        elif channels == 2:
            permitted_2_channel_modes = ['LA']
            if mode is not None and mode not in permitted_2_channel_modes:
                raise ValueError("Only modes {} are supported for 2D inputs".
                                 format(permitted_2_channel_modes))
            if mode is None and npimg.dtype == np.uint8:
                mode = 'LA'
        elif channels == 4:
            permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
            if mode is not None and mode not in permitted_4_channel_modes:
                raise ValueError("Only modes {} are supported for 4D inputs".
                                 format(permitted_4_channel_modes))
            if mode is None and npimg.dtype == np.uint8:
                mode = 'RGBA'
        else:
            # Any other channel count is validated against the 3-channel modes.
            permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
            if mode is not None and mode not in permitted_3_channel_modes:
                raise ValueError("Only modes {} are supported for 3D inputs".
                                 format(permitted_3_channel_modes))
            if mode is None and npimg.dtype == np.uint8:
                mode = 'RGB'

        if mode is None:
            raise TypeError('Input type {} is not supported'.format(
                npimg.dtype))
        return Image.fromarray(npimg, mode=mode)
# RobustVideoMatting/inference_utils.py
# These four classes are used later: from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
import av
import os
import pims
import numpy as np
from paddle.io import Dataset # 据说这个跟storch功能一致
# from paddle.vision.transforms.functional import to_pil_image
# Module-level ToPILImage instance (default mode=None) standing in for the
# commented-out import above; ImageSequenceWriter.write calls it once per frame.
to_pil_image = ToPILImage()
from PIL import Image
# @property creates a read-only attribute
class VideoReader(Dataset):
    """Dataset that exposes the frames of a video file as PIL images.

    The file is opened with ``pims.PyAVVideoReader``; each item is one frame,
    optionally passed through ``transform``.

    Args:
        path: Location of the video file handed to ``pims.PyAVVideoReader``.
        transform (callable, optional): Applied to each PIL frame before it
            is returned.
    """

    def __init__(self, path, transform=None):
        reader = pims.PyAVVideoReader(path)
        self.video = reader
        # Cache the frame rate reported by the reader at open time.
        self.rate = reader.frame_rate
        self.transform = transform

    @property
    def frame_rate(self):
        """Frame rate captured from the reader when the video was opened."""
        return self.rate

    def __len__(self):
        """Return the number of frames reported by the underlying reader."""
        return len(self.video)

    def __getitem__(self, idx):
        """Return frame ``idx`` as a PIL image, or its transformed result."""
        image = Image.fromarray(np.asarray(self.video[idx]))
        if self.transform is None:
            return image
        return self.transform(image)
class VideoWriter:
    """Encode batches of frames into an H.264 video file through PyAV.

    Args:
        path: Output file opened with ``av.open(path, mode='w')``.
        frame_rate: Stream frame rate; rounded to an integer before use.
        bit_rate (int): Target bit rate assigned to the stream.
    """

    def __init__(self, path, frame_rate, bit_rate=1000000):
        container = av.open(path, mode='w')
        stream = container.add_stream('h264', rate=round(frame_rate))
        stream.pix_fmt = 'yuv420p'
        stream.bit_rate = bit_rate
        self.container = container
        self.stream = stream

    def write(self, frames):
        """Encode and mux every frame of ``frames``.

        Args:
            frames: Tensor laid out as [T, C, H, W]. Single-channel input is
                tiled to three channels so it can be written as RGB.
                NOTE(review): values are multiplied by 255 and cast to uint8,
                so inputs are presumably in [0, 1] - confirm with callers.
        """
        dims = frames.shape
        # The stream geometry is taken from the incoming batch on every call.
        self.stream.width = dims[3]
        self.stream.height = dims[2]
        if dims[1] == 1:
            # Paddle equivalent of torch's repeat(1, 3, 1, 1): grayscale -> RGB.
            frames = frames.tile([1, 3, 1, 1])

        # Scale, move channels last ([T, H, W, C]) and hand over to numpy.
        scaled = frames * 255
        channels_last = scaled.transpose([0, 2, 3, 1])
        pixels = channels_last.astype('uint8').numpy()

        for image in pixels:
            video_frame = av.VideoFrame.from_ndarray(image, format='rgb24')
            self.container.mux(self.stream.encode(video_frame))

    def close(self):
        """Flush the encoder by encoding with no frame, then close the container."""
        remaining = self.stream.encode()
        self.container.mux(remaining)
        self.container.close()
class ImageSequenceReader(Dataset):
    """Dataset over the image files of a directory, in sorted filename order.

    Args:
        path: Directory whose entries (as listed by ``os.listdir``) are the
            frames of the sequence.
        transform (callable, optional): Applied to each loaded PIL image
            before it is returned.
    """

    def __init__(self, path, transform=None):
        entries = os.listdir(path)
        entries.sort()
        self.path = path
        self.files = entries
        self.transform = transform

    def __len__(self):
        """Return the number of directory entries found at construction."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load entry ``idx`` fully into memory and return it (transformed if set)."""
        location = os.path.join(self.path, self.files[idx])
        with Image.open(location) as img:
            # Force the pixel data to be read before the file is closed.
            img.load()
        return img if self.transform is None else self.transform(img)
class ImageSequenceWriter:
    """Write frames as numbered image files (``0000.jpg``, ``0001.jpg``, ...).

    Args:
        path: Output directory; created if it does not exist.
        extension (str): File extension appended after the zero-padded index.
    """

    def __init__(self, path, extension='jpg'):
        self.path = path
        self.extension = extension
        # Running index of the next file; continues across write() calls.
        self.counter = 0
        os.makedirs(path, exist_ok=True)

    def write(self, frames):
        """Save each frame of ``frames`` (laid out as [T, C, H, W]) to disk."""
        total = frames.shape[0]
        for index in range(total):
            picture = to_pil_image(frames[index])
            filename = str(self.counter).zfill(4) + '.' + self.extension
            picture.save(os.path.join(self.path, filename))
            self.counter += 1

    def close(self):
        """Do nothing; present so the writer has the same interface as VideoWriter."""
        pass
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/skywalk/paddlerobustvideomatting.git
git@gitee.com:skywalk/paddlerobustvideomatting.git
skywalk
paddlerobustvideomatting
PaddleRobustVideoMatting
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891