1 Star 6 Fork 0

shibwoen/stm32_tensorflow

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 1.83 KB
一键复制 编辑 原始数据 按行查看 历史
尖头史 提交于 2021-11-20 00:15 +08:00 . shibowen
import csv
import numpy as np
import tensorflow as tf
import os
# 采用2号GPU,单GPU可注释
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# 声明输入输出变量
x_d = np.zeros((80, 4), dtype='float32')
y_d = np.zeros((80, 1), dtype='float32')
# 导入训练集
f = csv.reader(open('dataset/train_data.csv', 'r'))
for i in f:
x_d[int(i[0])][0] = float(i[1])
x_d[int(i[0])][1] = float(i[2])
x_d[int(i[0])][2] = float(i[3])
x_d[int(i[0])][3] = float(i[4])
y_d[int(i[0])] = float(i[6])
# 训练模型搭建
# 输入层
x = tf.compat.v1.placeholder(tf.float32, shape=[None, 4], name='x')
y = tf.compat.v1.placeholder(tf.float32, shape=[None, 1], name='y')
# 隐藏层1
w1 = tf.Variable(tf.truncated_normal([4, 2]), name='w1')
b1 = tf.Variable(tf.truncated_normal([2], 0.1), name='b1')
l1 = tf.sigmoid(tf.matmul(x, w1) + b1, name='l1')
# 隐藏层2
w = tf.Variable(tf.truncated_normal([2, 1]), name='w')
b = tf.Variable(tf.truncated_normal([1], 0.1), name='b')
# 输出层
o = tf.sigmoid(tf.matmul(l1, w) + b, name='o')
loss = tf.reduce_mean(tf.square(o - y))
train = tf.train.GradientDescentOptimizer(0.9).minimize(loss)
# 初始化
init = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=4)
# 创建
with tf.Session() as sess:
writer = tf.summary.FileWriter("logs/", sess.graph)
sess.run(init)
max_step = 9000
for i in range(max_step + 1):
sess.run(train, feed_dict={x: x_d, y: y_d})
cost = sess.run(loss, feed_dict={x: x_d, y: y_d})
if i % 1000 == 0:
print("------------------------------------------------------")
# 保存模型
saver.save(sess, "model/ckpt/model", global_step=i)
print("------------------------------------------------------")
print('step: ' + str(i) + ' loss:' + "{:.3f}".format(cost))
print('训练结束')
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/shibwoen/stm32_tensorflow.git
git@gitee.com:shibwoen/stm32_tensorflow.git
shibwoen
stm32_tensorflow
stm32_tensorflow
master

搜索帮助