1 Star 1 Fork 6

张觉非 / 计算图框架

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test_sobel.py 2.88 KB
一键复制 编辑 原始数据 按行查看 历史
张觉非 提交于 2019-07-25 19:30 . 整理
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 24 10:37:12 2019
@author: zhangjuefei
"""
import skimage, os
import numpy as np
from node import *
from optimizer import *
from scipy.ndimage.filters import convolve
import matplotlib.pyplot as plt
pic_root = "E:/train_pic/fruit/train/lena"
os.chdir(pic_root)
filter = np.mat([
[.019, .100, .019],
[.1, .531, .100],
[.019, .100, .019]])
# Sobel 滤波器
# filter = np.mat([
# [1., 0., -1.],
# [2., 0., -2.],
# [1., 0., -1.]])
# 图像尺寸
w, h = 128, 128
# 搭建计算图
input_img = Variable((w, h), init=False, trainable=False) # 输入图像占位变量
target_img = Variable((w, h), init=False, trainable=False) # target 图像占位变量
kernel = Variable((3, 3), init=True, trainable=True) # 被训练的卷积核
# 对输入图像施加滤波
conv_img = Convolve(input_img, kernel)
# 将输入图像和 target 图像展成向量
target_img_flat = Reshape(target_img, shape=(w * h, 1))
conv_img_flat = Reshape(conv_img, shape=(w * h, 1))
# 常数(矩阵)[[-1]]
minus = Variable((1, 1), init=False, trainable=False)
minus.set_value(np.mat(-1))
# 常数(矩阵):图像总像素数的倒数
n = Variable((1, 1), init=False, trainable=False)
n.set_value(np.mat(1.0 / (w * h)))
# 损失值:输出图像与 target 图像的平均像素平方误差
loss = MatMul(Dot(
Add(target_img_flat, MatMul(conv_img_flat, minus)),
Add(target_img_flat, MatMul(conv_img_flat, minus))
), n)
# RMSProp 优化器
optimizer = Adam(default_graph, loss, 0.06, batch_size=1)
# 读取 lena 图,将 rgb 图像转成单通道灰度图,并 resize 成指定大小
img = skimage.transform.resize(
skimage.color.rgb2gray(
skimage.io.imread("lena.png")
),
(w, h)
)
# 制作目标图像:对原图像施加 Sobel 滤波器
sobel = np.mat([[1,0,-1],[2,0,-2],[1,0,-1]])
target = np.zeros((w, h))
convolve(input=img, output=target, weights=filter, mode="constant", cval=0.0)
# 保存原图和经过 Sobel 滤波的图像
skimage.io.imsave("origin.png", np.minimum(np.maximum(img, 0.0), 1.0))
skimage.io.imsave("target.png", np.minimum(np.maximum(target, 0.0), 1.0))
# 以原图为输入,以经过 Sobel 滤波的图像为 target
input_img.set_value(np.mat(img))
target_img.set_value(np.mat(target))
i = 0
for e in range(1000):
# 一次迭代
optimizer.one_step()
# 计算损失值
loss.forward()
print("pic:{:s},loss:{:.6f}".format(p, loss.value[0, 0]))
print(kernel.value)
# 保存当前卷积核对输入图像做滤波的结果
fname = "{:d}.png".format(i)
if os.path.exists(fname):
os.remove(fname)
skimage.io.imsave(fname, np.minimum(np.maximum(conv_img.value, 0.0), 1.0))
i += 1
Python
1
https://gitee.com/zhangjuefei/computing_graph_demo.git
git@gitee.com:zhangjuefei/computing_graph_demo.git
zhangjuefei
computing_graph_demo
计算图框架
master

搜索帮助