1 Star 0 Fork 0

wptoux/triton-inference-server-image-preprocess

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
preprocess-model.ipynb 93.67 KB
一键复制 编辑 原始数据 按行查看 历史
Zhen Wang 提交于 5年前 . minor update
import os

import numpy as np
import tensorflow as tf

from PIL import Image
class Preprocessor(tf.Module):
    """Decode, resize and normalize a batch of encoded images.

    Calling an instance with a 1-D string tensor of encoded image bytes
    returns a float32 tensor laid out as ``[batch, 3, height, width]``:
    every image is decoded to 3 channels, resized with padding to
    ``height x width``, moved to channels-first order, scaled by ``1/255``
    and standardized per channel with ``mean`` and ``std``.

    Args:
        height: Target image height passed to ``tf.image.resize_with_pad``.
        width: Target image width passed to ``tf.image.resize_with_pad``.
        mean: Three per-channel means applied after scaling to ``[0, 1]``
            (the defaults look like the usual ImageNet statistics).
        std: Three per-channel standard deviations used as divisors.

    Raises:
        ValueError: If ``mean`` or ``std`` does not hold exactly 3 values.
    """

    def __init__(self, height, width, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # tf.Module sets up its own bookkeeping (name, name scope) in
        # __init__, so the base initializer must run.
        super().__init__()

        # Immutable copies: the defaults are no longer shared mutable lists
        # and later changes to a caller's list cannot leak into the module.
        mean = tuple(float(v) for v in mean)
        std = tuple(float(v) for v in std)
        if len(mean) != 3 or len(std) != 3:
            # normalize() reshapes both to [1, 3, 1, 1]; fail early with a
            # clear message instead of a reshape error at trace time.
            raise ValueError(
                "mean and std must each contain exactly 3 values, "
                f"got {len(mean)} and {len(std)}"
            )

        self.height = height
        self.width = width
        self.mean = mean
        self.std = std

    def image_decoder(self, image_bytes):
        """Decode one encoded image and return it as a CHW tensor.

        Args:
            image_bytes: Scalar string tensor with the encoded image.

        Returns:
            The image resized with padding to ``(height, width)`` and
            transposed from HWC to CHW.
        """
        image = tf.io.decode_image(
            image_bytes,
            channels=3,
            dtype=tf.dtypes.uint8,
            name="decode_image",
            # Keep a static rank of 3 (animated inputs are not expanded
            # into a stack of frames).
            expand_animations=False
        )

        # Aspect-ratio preserving resize; the remainder is padded.
        image = tf.image.resize_with_pad(
            image, target_height=self.height, target_width=self.width
        )
        # HWC -> CHW
        image = tf.transpose(image, perm=[2, 0, 1])

        return image

    def normalize(self, tensor):
        """Scale pixel values to ``[0, 1]`` and standardize each channel.

        Args:
            tensor: Batch of images in ``[batch, 3, height, width]`` layout
                (the per-channel constants are broadcast along axis 1).

        Returns:
            float32 tensor ``(tensor / 255 - mean) / std``.
        """
        tensor = tf.cast(tensor, tf.float32)
        tensor = tensor / 255.0

        # Reshape to [1, 3, 1, 1] so the constants broadcast over batch,
        # height and width.
        mean = tf.constant(self.mean, dtype=tf.float32, name="mean")
        mean = tf.reshape(mean, [1, 3, 1, 1])
        std = tf.constant(self.std, dtype=tf.float32, name="std")
        std = tf.reshape(std, [1, 3, 1, 1])

        tensor = (tensor - mean) / std
        return tensor

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def __call__(self, raw_images):
        """Preprocess a batch of encoded images.

        Args:
            raw_images: 1-D string tensor, one encoded image per element.

        Returns:
            Normalized float32 tensor produced by ``normalize``.
        """
        # Decoding is pinned to the CPU device.
        with tf.device("/cpu:0"):
            # NOTE(review): newer TensorFlow releases prefer
            # ``fn_output_signature`` over ``dtype`` here; ``dtype`` is kept
            # for compatibility with older 2.x versions -- confirm the
            # minimum supported version before switching.
            images = tf.map_fn(self.image_decoder, raw_images, dtype=tf.float32)
        tensor = self.normalize(images)
        return tensor
# Read the raw encoded bytes of two sample images. The context managers
# close the file handles deterministically instead of leaking them.
with open('./test.jpg', 'rb') as jpg_file:
    bim = jpg_file.read()
with open('test1.png', 'rb') as png_file:
    bim1 = png_file.read()

# Target size (height, width) of the preprocessed images.
h = 550
w = 800

prep = Preprocessor(h, w)
img_normed = prep([bim, bim1]).numpy()

# Sanity check: undo the normalization of the second image using the same
# statistics the preprocessor applied, rather than repeating the constants.
r, g, b = img_normed[1]
r = r * prep.std[0] + prep.mean[0]
g = g * prep.std[1] + prep.mean[1]
b = b * prep.std[2] + prep.mean[2]

# Round and clip before the uint8 cast so float round-off cannot truncate
# a value down by one or wrap around at the 0/255 boundaries.
# As the last expression of a notebook cell this renders the image; in a
# plain script the result is simply discarded.
Image.fromarray(
    np.transpose(
        np.clip(np.rint(np.stack([r, g, b]) * 255), 0, 255).astype('uint8'),
        (1, 2, 0),  # CHW -> HWC, the layout PIL expects
    )
)

# Export the module as a SavedModel under a versioned ("1") directory.
tf.saved_model.save(prep, f'./example/preprocess_{h}x{w}/1/model.saved_model/')
INFO:tensorflow:Assets written to: ./example/preprocess_550x800/1/model.saved_model/assets
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/wptoux/triton-inference-server-image-preprocess.git
git@gitee.com:wptoux/triton-inference-server-image-preprocess.git
wptoux
triton-inference-server-image-preprocess
triton-inference-server-image-preprocess
master

搜索帮助