1 Star 0 Fork 1

北部湾的落日 / imageProcessing

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
study_2.py 2.23 KB
一键复制 编辑 原始数据 按行查看 历史
北部湾的落日 提交于 2018-05-09 09:58 . Initial commit
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def add_layer(inputs, in_size, out_size, activation_function=None): # activation_function=None线性函数
Weights = tf.Variable(tf.random_normal([in_size, out_size])) # Weight中都是随机变量
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1) # biases推荐初始值不为0
Wx_plus_b = tf.matmul(inputs, Weights) + biases # inputs*Weight+biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs
# 创建数据x_data,y_data
x_data = np.linspace(-1, 1, 300)[:, np.newaxis] # [-1,1]区间,300个单位,np.newaxis增加维度
noise = np.random.normal(0, 0.05, x_data.shape) # 噪点
y_data = np.square(x_data) - 0.5 + noise
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# 三层神经,输入层(1个神经元),隐藏层(10神经元),输出层(1个神经元)
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu) # 输入层
prediction = add_layer(l1, 10, 1, activation_function=None) # 隐藏层
# predition值与y_data差别
loss = tf.reduce_mean(
tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1])) # square()平方,sum()求和,mean()平均值
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss) # 反向传播,0.1学习效率,minimize(loss)减小loss误差
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init) # 先执行init
# 可视化
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(x_data, y_data)
plt.ion() # 不让show() block
plt.show()
# 训练1k次
for i in range(1000):
sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
if i % 50 == 0:
try:
ax.lines.remove(lines[0]) # lines建一个抹除一个
except Exception:
pass
# print(sess.run(loss,feed_dict={xs:x_data,ys:y_data})) #输出loss值
# 可视化
prediction_value = sess.run(prediction, feed_dict={xs: x_data, ys: y_data})
lines = ax.plot(x_data, prediction_value, 'r-', lw=5) # x_data X轴,prediction_value Y轴,'r-'红线,lw=5线宽5
plt.pause(0.1) # 暂停0.1秒
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/cbyh1313/imageProcessing.git
git@gitee.com:cbyh1313/imageProcessing.git
cbyh1313
imageProcessing
imageProcessing
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891