3 Star 0 Fork 0

mirrors_fendouai / Chinese-Text-Classification

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
eval.py 3.69 KB
一键复制 编辑 原始数据 按行查看 历史
feiwang 提交于 2017-11-09 11:25 . init
#! /usr/bin/env python
import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
from text_cnn import TextCNN
from tensorflow.contrib import learn
import csv
# Parameters
# ==================================================
# Data Parameters
tf.flags.DEFINE_string("positive_data_file", "./data/test_text/pos.txt", "Data source for the positive data.")
tf.flags.DEFINE_string("negative_data_file", "./data/test_text/neg.txt", "Data source for the negative data.")
# Eval Parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")
tf.flags.DEFINE_boolean("eval_train", False, "Evaluate on all training data")
# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
print("{}={}".format(attr.upper(), value))
print("")
# CHANGE THIS: Load data. Load your own data here
if FLAGS.eval_train:
x_raw, y_test = data_helpers.load_test_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)
y_test = np.argmax(y_test, axis=1)
else:
x_raw = ["a masterpiece four years in the making", "everything is off."]
y_test = [1, 0]
# Map data into vocabulary
vocab_path = os.path.join(FLAGS.checkpoint_dir, "..", "vocab")
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path)
text_list=[]
for text in x_raw:
text_list.append(' '.join(text))
x_test = np.array(list(vocab_processor.transform(text_list)))
print("\nEvaluating...\n")
# Evaluation
# ==================================================
checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
graph = tf.Graph()
with graph.as_default():
session_conf = tf.ConfigProto(
allow_soft_placement=FLAGS.allow_soft_placement,
log_device_placement=FLAGS.log_device_placement)
sess = tf.Session(config=session_conf)
with sess.as_default():
# Load the saved meta graph and restore variables
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess, checkpoint_file)
# Get the placeholders from the graph by name
input_x = graph.get_operation_by_name("input_x").outputs[0]
# input_y = graph.get_operation_by_name("input_y").outputs[0]
dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
# Tensors we want to evaluate
predictions = graph.get_operation_by_name("output/predictions").outputs[0]
# Generate batches for one epoch
batches = data_helpers.batch_iter(list(x_test), FLAGS.batch_size, 1)
# Collect the predictions here
all_predictions = []
for x_test_batch in batches:
batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
all_predictions = np.concatenate([all_predictions, batch_predictions])
# Print accuracy if y_test is defined
if y_test is not None:
correct_predictions = float(sum(all_predictions == y_test))
print("Total number of test examples: {}".format(len(y_test)))
print("Accuracy: {:g}".format(correct_predictions/float(len(y_test))))
# Save the evaluation to a csv
predictions_human_readable = np.column_stack((np.array(x_raw), all_predictions))
out_path = os.path.join(FLAGS.checkpoint_dir, "..", "prediction.csv")
print("Saving evaluation to {0}".format(out_path))
with open(out_path, 'w') as f:
csv.writer(f).writerows(predictions_human_readable)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mirrors_fendouai/Chinese-Text-Classification.git
git@gitee.com:mirrors_fendouai/Chinese-Text-Classification.git
mirrors_fendouai
Chinese-Text-Classification
Chinese-Text-Classification
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891