This is a mirror repository, synchronized once a day to speed up downloads within China. Original repository: https://github.com/tensorflow/fold

Sentiment Analysis with TreeLSTMs in TensorFlow Fold

The Stanford Sentiment Treebank is a corpus of ~10K one-sentence movie reviews from Rotten Tomatoes. The sentences have been parsed into binary trees with words at the leaves; every sub-tree has a label ranging from 0 (highly negative) to 4 (highly positive); 2 means neutral.

For example, (4 (2 Spiderman) (3 ROCKS)) is a sentence with two words, corresponding to a binary tree with three nodes. The label at the root, for the entire sentence, is 4 (highly positive). The label for the left child, a leaf corresponding to the word Spiderman, is 2 (neutral). The label for the right child, a leaf corresponding to the word ROCKS, is 3 (moderately positive).
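To make the tree format concrete, here is a minimal pure-Python parser for a labeled binary tree like the Spiderman example. The helpers `parse_tree` and `count_nodes` are hypothetical and exist only for illustration (the notebook itself tokenizes with NLTK's `sexpr_tokenize`); the sketch assumes every internal node has exactly two children and that words contain no whitespace or parentheses.

```python
import re

def parse_tree(s):
    """Parses '(label word)' / '(label left right)' into nested tuples.

    Leaves become (label, word); internal nodes become (label, left, right).
    """
    tokens = iter(re.findall(r'\(|\)|[^\s()]+', s))

    def parse(open_paren):
        assert open_paren == '('
        label = int(next(tokens))
        tok = next(tokens)
        if tok == '(':
            # Internal node: a binary treebank has exactly two subtrees.
            left = parse(tok)
            right = parse(next(tokens))
            node = (label, left, right)
        else:
            # Leaf: a single word.
            node = (label, tok)
        assert next(tokens) == ')'
        return node

    return parse(next(tokens))

def count_nodes(node):
    """Counts every node (leaves and internal nodes) in a parsed tree."""
    if len(node) == 2:
        return 1
    return 1 + count_nodes(node[1]) + count_nodes(node[2])

tree = parse_tree('(4 (2 Spiderman) (3 ROCKS))')
print(tree)               # (4, (2, 'Spiderman'), (3, 'ROCKS'))
print(count_nodes(tree))  # 3
```

Every node, not only the root, carries a label, which is why the model below produces a prediction (and a loss) at each of the three nodes.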

This notebook shows how to use TensorFlow Fold to train a model on the treebank using binary TreeLSTMs and GloVe word embedding vectors, as described in the paper Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks by Tai et al. The authors have also released the original Torch source code for the model.

The model illustrates three of the more advanced features of Fold, namely:

  1. Compositions to wire up blocks to form arbitrary directed acyclic graphs
  2. Forward Declarations to create recursive blocks
  3. Metrics to create models where the size of the output is not fixed, but varies as a function of the input data.
# boilerplate
import codecs
import functools
import os
import tempfile
import zipfile

from nltk.tokenize import sexpr
import numpy as np
from six.moves import urllib
import tensorflow as tf
sess = tf.InteractiveSession()
import tensorflow_fold as td

Get the data

Begin by fetching the word embedding vectors and treebank sentences.

data_dir = tempfile.mkdtemp()
print('saving files to %s' % data_dir)
saving files to /tmp/tmpPhKqpj
def download_and_unzip(url_base, zip_name, *file_names):
  zip_path = os.path.join(data_dir, zip_name)
  url = url_base + zip_name
  print('downloading %s to %s' % (url, zip_path))
  urllib.request.urlretrieve(url, zip_path)
  out_paths = []
  with zipfile.ZipFile(zip_path, 'r') as f:
    for file_name in file_names:
      print('extracting %s' % file_name)
      out_paths.append(f.extract(file_name, path=data_dir))
  return out_paths
  
full_glove_path, = download_and_unzip(
  'http://nlp.stanford.edu/data/', 'glove.840B.300d.zip',
  'glove.840B.300d.txt')
downloading http://nlp.stanford.edu/data/glove.840B.300d.zip to /tmp/tmpPhKqpj/glove.840B.300d.zip
extracting glove.840B.300d.txt
train_path, dev_path, test_path = download_and_unzip(
  'http://nlp.stanford.edu/sentiment/', 'trainDevTestTrees_PTB.zip', 
  'trees/train.txt', 'trees/dev.txt', 'trees/test.txt')
downloading http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip to /tmp/tmpPhKqpj/trainDevTestTrees_PTB.zip
extracting trees/train.txt
extracting trees/dev.txt
extracting trees/test.txt

Filter out words that don't appear in the dataset, since the full set of GloVe embeddings is large (5GB). This is purely a performance optimization and has no effect on the final results.

filtered_glove_path = os.path.join(data_dir, 'filtered_glove.txt')
def filter_glove():
  vocab = set()
  # Download the full set of unlabeled sentences, whose tokens are separated by '|'.
  sentence_path, = download_and_unzip(
    'http://nlp.stanford.edu/~socherr/', 'stanfordSentimentTreebank.zip', 
    'stanfordSentimentTreebank/SOStr.txt')
  with codecs.open(sentence_path, encoding='utf-8') as f:
    for line in f:
      # Drop the trailing newline and strip backslashes. Split into words.
      vocab.update(line.strip().replace('\\', '').split('|'))
  nread = 0
  nwrote = 0
  with codecs.open(full_glove_path, encoding='utf-8') as f:
    with codecs.open(filtered_glove_path, 'w', encoding='utf-8') as out:
      for line in f:
        nread += 1
        line = line.strip()
        if not line: continue
        if line.split(u' ', 1)[0] in vocab:
          out.write(line + '\n')
          nwrote += 1
  print('read %s lines, wrote %s' % (nread, nwrote))
filter_glove()
downloading http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip to /tmp/tmpPhKqpj/stanfordSentimentTreebank.zip
extracting stanfordSentimentTreebank/SOStr.txt
read 2196018 lines, wrote 20725
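The filtering step is just a set-membership test on the first token of each GloVe line. The sketch below restates the logic of `filter_glove()` over in-memory lists so it can be tried without the 5GB download; the function name `filter_embedding_lines` and the toy lines are made up for illustration.

```python
def filter_embedding_lines(glove_lines, sostr_lines):
    """Keeps only embedding lines whose word appears in the treebank vocab."""
    vocab = set()
    for line in sostr_lines:
        # Same normalization as filter_glove(): drop the newline, strip
        # backslashes, and split the sentence into tokens on '|'.
        vocab.update(line.strip().replace('\\', '').split('|'))
    kept = []
    for line in glove_lines:
        line = line.strip()
        if not line:
            continue
        # The word is everything before the first space; the rest is the vector.
        if line.split(' ', 1)[0] in vocab:
            kept.append(line)
    return kept

sostr = ['Spiderman|ROCKS\n', 'a|movie\n']
glove = ['Spiderman 0.1 0.2\n', 'ROCKS 0.3 0.4\n', 'zebra 0.5 0.6\n', '\n']
print(filter_embedding_lines(glove, sostr))
# ['Spiderman 0.1 0.2', 'ROCKS 0.3 0.4']
```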

Load the filtered word embeddings into a matrix and build a dict from words to indices into the matrix. Add a random embedding vector for out-of-vocabulary words.

def load_embeddings(embedding_path):
  """Loads embedings, returns weight matrix and dict from words to indices."""
  print('loading word embeddings from %s' % embedding_path)
  weight_vectors = []
  word_idx = {}
  with codecs.open(embedding_path, encoding='utf-8') as f:
    for line in f:
      word, vec = line.split(u' ', 1)
      word_idx[word] = len(weight_vectors)
      weight_vectors.append(np.array(vec.split(), dtype=np.float32))
  # Annoying implementation detail; '(' and ')' are replaced by '-LRB-' and
  # '-RRB-' respectively in the parse-trees.
  word_idx[u'-LRB-'] = word_idx.pop(u'(')
  word_idx[u'-RRB-'] = word_idx.pop(u')')
  # Random embedding vector for unknown words.
  weight_vectors.append(np.random.uniform(
      -0.05, 0.05, weight_vectors[0].shape).astype(np.float32))
  return np.stack(weight_vectors), word_idx
weight_matrix, word_idx = load_embeddings(filtered_glove_path)
loading word embeddings from /tmp/tmpPhKqpj/filtered_glove.txt
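One invariant is worth spelling out: the `-LRB-`/`-RRB-` renaming pops one key and adds another, so `len(word_idx)` still equals the number of real vectors, and therefore points at the extra random row appended last. `logits_and_state()` relies on this when it sets `unknown_idx = len(word_idx)`. The sketch below mimics `load_embeddings()` on in-memory lines to show the lookup; `build_embeddings` and the two-dimensional toy vectors are hypothetical.

```python
import numpy as np

def build_embeddings(lines, seed=0):
    """Mimics load_embeddings() on in-memory 'word v1 v2 ...' lines."""
    rng = np.random.RandomState(seed)
    vectors, word_idx = [], {}
    for line in lines:
        word, vec = line.split(' ', 1)
        word_idx[word] = len(vectors)
        vectors.append(np.array(vec.split(), dtype=np.float32))
    # Extra row for unknown words, drawn from U(-0.05, 0.05) as in the notebook.
    vectors.append(rng.uniform(-0.05, 0.05, vectors[0].shape).astype(np.float32))
    return np.stack(vectors), word_idx

weights, word_idx = build_embeddings(['good 1 0', 'bad 0 1'])
unknown_idx = len(word_idx)  # index of the last (random) row
lookup = lambda word: word_idx.get(word, unknown_idx)
print(weights.shape)                        # (3, 2)
print(lookup('good'), lookup('Spiderman'))  # 0 2
```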

Finally, load the treebank data.

def load_trees(filename):
  with codecs.open(filename, encoding='utf-8') as f:
    # Drop the trailing newline and strip backslashes.
    trees = [line.strip().replace('\\', '') for line in f]
    print('loaded %s trees from %s' % (len(trees), filename))
    return trees
train_trees = load_trees(train_path)
dev_trees = load_trees(dev_path)
test_trees = load_trees(test_path)
loaded 8544 trees from /tmp/tmpPhKqpj/trees/train.txt
loaded 1101 trees from /tmp/tmpPhKqpj/trees/dev.txt
loaded 2210 trees from /tmp/tmpPhKqpj/trees/test.txt

Build the model

We want to compute a hidden state vector h for every node in the tree. The hidden state is the input to a linear layer with softmax output for predicting the sentiment label.

At the leaves of the tree, words are mapped to word-embedding vectors which serve as the input to a binary tree-LSTM with 0 for the previous states. At the internal nodes, the LSTM takes 0 as input, and previous states from its two children. More formally,

$$h_{word} = \mathrm{TreeLSTM}(\mathrm{Embedding}(word), 0, 0)$$
$$h_{left,right} = \mathrm{TreeLSTM}(0, h_{left}, h_{right})$$

where $\mathrm{TreeLSTM}(x, h_{left}, h_{right})$ is a special kind of LSTM cell that takes two hidden states as inputs and has a separate forget gate for each of them. Specifically, it is Tai et al. eqs. 9-14 with $N=2$. One modification here from Tai et al. is that instead of L2 weight regularization, we use recurrent dropout as described in the paper Recurrent Dropout without Memory Loss.

We can implement TreeLSTM by subclassing the TensorFlow BasicLSTMCell.

class BinaryTreeLSTMCell(tf.contrib.rnn.BasicLSTMCell):
  """LSTM with two state inputs.

  This is the model described in section 3.2 of 'Improved Semantic
  Representations From Tree-Structured Long Short-Term Memory
  Networks' <http://arxiv.org/pdf/1503.00075.pdf>, with recurrent
  dropout as described in 'Recurrent Dropout without Memory Loss'
  <http://arxiv.org/pdf/1603.05118.pdf>.
  """

  def __init__(self, num_units, keep_prob=1.0):
    """Initialize the cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      keep_prob: Keep probability for recurrent dropout.
    """
    super(BinaryTreeLSTMCell, self).__init__(num_units)
    self._keep_prob = keep_prob

  def __call__(self, inputs, state, scope=None):
    with tf.variable_scope(scope or type(self).__name__):
      lhs, rhs = state
      c0, h0 = lhs
      c1, h1 = rhs
      concat = tf.contrib.layers.linear(
          tf.concat([inputs, h0, h1], 1), 5 * self._num_units)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f0, f1, o = tf.split(value=concat, num_or_size_splits=5, axis=1)

      j = self._activation(j)
      if not isinstance(self._keep_prob, float) or self._keep_prob < 1:
        j = tf.nn.dropout(j, self._keep_prob)

      new_c = (c0 * tf.sigmoid(f0 + self._forget_bias) +
               c1 * tf.sigmoid(f1 + self._forget_bias) +
               tf.sigmoid(i) * j)
      new_h = self._activation(new_c) * tf.sigmoid(o)

      new_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)

      return new_h, new_state
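To check the gate arithmetic outside of TensorFlow, here is a NumPy sketch of the same forward computation as `BinaryTreeLSTMCell.__call__`: $c = c_0\,\sigma(f_0 + b_f) + c_1\,\sigma(f_1 + b_f) + \sigma(i)\tanh(j)$ and $h = \tanh(c)\,\sigma(o)$. Assumptions: `W` and `b` are hypothetical stand-ins for the variables that `tf.contrib.layers.linear` creates, the forget bias $b_f$ is taken to be 1.0 (the `BasicLSTMCell` default), the activation is tanh, and the dropout mask is a simplified imitation of `tf.nn.dropout`, not its exact implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def binary_tree_lstm_step(x, lhs, rhs, W, b, forget_bias=1.0,
                          keep_prob=1.0, rng=None):
    """One forward step of a binary TreeLSTM cell (NumPy sketch).

    x: input vector of shape [input_dim].
    lhs, rhs: (c, h) state tuples of the two children, each [num_units].
    W: [input_dim + 2 * num_units, 5 * num_units]; b: [5 * num_units].
    Returns (new_h, (new_c, new_h)), mirroring the TF cell's return value.
    """
    c0, h0 = lhs
    c1, h1 = rhs
    # One affine map over [x, h0, h1], split into five gate pre-activations.
    concat = np.concatenate([x, h0, h1]).dot(W) + b
    i, j, f0, f1, o = np.split(concat, 5)
    j = np.tanh(j)
    if keep_prob < 1.0:
        # Recurrent dropout: only the candidate j is dropped (inverted scaling),
        # never the cell states c0 and c1 carried up from the children.
        if rng is None:
            rng = np.random.RandomState(0)
        mask = rng.uniform(size=j.shape) < keep_prob
        j = j * mask / keep_prob
    # Separate forget gates f0 and f1 for the left and right child memories.
    new_c = (c0 * sigmoid(f0 + forget_bias) +
             c1 * sigmoid(f1 + forget_bias) +
             sigmoid(i) * j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)

# Toy sizes and random weights standing in for trained variables.
num_units, input_dim = 4, 3
rng = np.random.RandomState(0)
W = rng.normal(scale=0.1, size=(input_dim + 2 * num_units, 5 * num_units))
b = np.zeros(5 * num_units)
zero_state = (np.zeros(num_units), np.zeros(num_units))

# Leaves: word vector as input, zero child states.
_, left_state = binary_tree_lstm_step(
    rng.normal(size=input_dim), zero_state, zero_state, W, b)
_, right_state = binary_tree_lstm_step(
    rng.normal(size=input_dim), zero_state, zero_state, W, b)
# Internal node: zero input, the two children's states.
h_root, root_state = binary_tree_lstm_step(
    np.zeros(input_dim), left_state, right_state, W, b)
print(h_root.shape)  # (4,)
```

Because the dropout mask touches only $j$, the paths $c_0 \to c$ and $c_1 \to c$ are never zeroed, which is the "without memory loss" property the cell borrows from the recurrent dropout paper.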

Use a placeholder for the dropout keep probability, with a default of 1 (for eval).

keep_prob_ph = tf.placeholder_with_default(1.0, [])

Create the LSTM cell for our model. In addition to recurrent dropout, apply dropout to inputs and outputs, using TF's built-in dropout wrapper. Put the LSTM cell inside of a td.ScopedLayer in order to manage variable scoping. This ensures that our LSTM's variables are encapsulated from the rest of the graph and get created exactly once.

lstm_num_units = 300  # Tai et al. used 150, but our regularization strategy is more effective
tree_lstm = td.ScopedLayer(
      tf.contrib.rnn.DropoutWrapper(
          BinaryTreeLSTMCell(lstm_num_units, keep_prob=keep_prob_ph),
          input_keep_prob=keep_prob_ph, output_keep_prob=keep_prob_ph),
      name_or_scope='tree_lstm')

Create the output layer using td.FC.

NUM_CLASSES = 5  # number of distinct sentiment labels
output_layer = td.FC(NUM_CLASSES, activation=None, name='output_layer')

Create the word embedding using td.Embedding. Note that the built-in Fold layers like Embedding and FC manage variable scoping automatically, so there is no need to put them inside scoped layers.

word_embedding = td.Embedding(
    *weight_matrix.shape, initializer=weight_matrix, name='word_embedding')

We now have layers that encapsulate all of the trainable variables for our model. The next step is to create the Fold blocks that define how inputs (s-expressions encoded as strings) get processed and used to make predictions. Naturally, this requires a recursive model, which we handle in Fold using a forward declaration. The recursive step is to take a subtree (represented as a string) and convert it into a hidden state vector (the LSTM state), thus embedding it in an n-dimensional space (where here n=300).

embed_subtree = td.ForwardDeclaration(name='embed_subtree')

The core of the model is a block that takes as input a list of tokens. The tokens will be either:

  • [word] - a leaf with a single word, the base-case for the recursion, or
  • [lhs, rhs] - an internal node consisting of a pair of sub-expressions

The outputs of the block will be a pair consisting of logits (the prediction) and the LSTM state.

def logits_and_state():
  """Creates a block that goes from tokens to (logits, state) tuples."""
  unknown_idx = len(word_idx)
  lookup_word = lambda word: word_idx.get(word, unknown_idx)
  
  word2vec = (td.GetItem(0) >> td.InputTransform(lookup_word) >>
              td.Scalar('int32') >> word_embedding)

  pair2vec = (embed_subtree(), embed_subtree())

  # Trees are binary, so the tree layer takes two states as its input_state.
  zero_state = td.Zeros((tree_lstm.state_size,) * 2)
  # Input is a word vector.
  zero_inp = td.Zeros(word_embedding.output_type.shape[0])

  word_case = td.AllOf(word2vec, zero_state)
  pair_case = td.AllOf(zero_inp, pair2vec)

  tree2vec = td.OneOf(len, [(1, word_case), (2, pair_case)])

  return tree2vec >> tree_lstm >> (output_layer, td.Identity())

Note that we use the call operator () to create blocks that reference the embed_subtree forward declaration, for the recursive case.
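Ignoring batching, the computation that the forward declaration describes is ordinary post-order recursion: each leaf feeds its embedded word and two zero states to the cell, each internal node feeds a zero input and its children's states, and every node emits a prediction. The plain-Python sketch below spells that out. `embed_tree_sketch` and the toy callables are hypothetical; in the real model `cell` is the TreeLSTM (which returns both an output and a state, simplified here to one value) and `output_layer` produces the logits. Trees are assumed to be nested tuples rather than the s-expression strings the Fold model consumes.

```python
def embed_tree_sketch(node, embed_word, cell, output_layer, zero_input,
                      zero_state, outputs):
    """Post-order recursion mirroring the structure of the Fold model.

    node: (label, word) for a leaf, or (label, left, right) for an internal node.
    Appends (label, output_layer(state)) for every node to `outputs` and
    returns the node's state so that the parent can consume it.
    """
    if len(node) == 2:
        # Base case (word_case): embedded word in, zero states for both children.
        _, word = node
        state = cell(embed_word(word), zero_state, zero_state)
    else:
        # Recursive case (pair_case): zero input, states of both subtrees.
        _, left, right = node
        left_state = embed_tree_sketch(left, embed_word, cell, output_layer,
                                       zero_input, zero_state, outputs)
        right_state = embed_tree_sketch(right, embed_word, cell, output_layer,
                                        zero_input, zero_state, outputs)
        state = cell(zero_input, left_state, right_state)
    outputs.append((node[0], output_layer(state)))
    return state

# Toy stand-ins: the "embedding" is the word length and the "cell" adds its
# three arguments, so the data flow is easy to trace by hand.
outputs = []
root_state = embed_tree_sketch(
    (4, (2, 'Spiderman'), (3, 'ROCKS')),
    embed_word=len, cell=lambda x, l, r: x + l + r,
    output_layer=lambda s: s, zero_input=0, zero_state=0, outputs=outputs)
print(outputs)  # [(2, 9), (3, 5), (4, 14)]
```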

Define a per-node loss function for training.

def tf_node_loss(logits, labels):
  return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)

Additionally, calculate fine-grained and binary hits (i.e. un-normalized accuracy) for evals. Fine-grained accuracy is defined over all five class labels and will be calculated for all labels, whereas binary accuracy is defined over negative vs. positive classification and will not be calculated for neutral labels.

def tf_fine_grained_hits(logits, labels):
  predictions = tf.cast(tf.argmax(logits, 1), tf.int32)
  return tf.cast(tf.equal(predictions, labels), tf.float64)
def tf_binary_hits(logits, labels):
  softmax = tf.nn.softmax(logits)
  binary_predictions = (softmax[:, 3] + softmax[:, 4]) > (softmax[:, 0] + softmax[:, 1])
  binary_labels = labels > 2
  return tf.cast(tf.equal(binary_predictions, binary_labels), tf.float64)
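The two hit functions can be checked by hand with a NumPy restatement on a few rows of made-up probabilities. Note that the binary prediction compares the mass on classes 3 and 4 against classes 0 and 1 and ignores class 2, and that a neutral label 2 would count as "negative" under `labels > 2`, which is why binary hits are only attached to non-neutral nodes. The function names below are hypothetical NumPy counterparts, not part of Fold.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def fine_grained_hits(logits, labels):
    """1.0 where the argmax over all five classes equals the label."""
    return (logits.argmax(axis=1) == labels).astype(np.float64)

def binary_hits(logits, labels):
    """1.0 where positive-vs-negative mass agrees with label > 2."""
    p = softmax(logits)
    binary_predictions = (p[:, 3] + p[:, 4]) > (p[:, 0] + p[:, 1])
    binary_labels = labels > 2
    return (binary_predictions == binary_labels).astype(np.float64)

# Made-up class probabilities; log() turns them into logits.
logits = np.log(np.array([
    [0.05, 0.05, 0.10, 0.60, 0.20],  # argmax 3, mostly positive
    [0.30, 0.40, 0.20, 0.05, 0.05],  # argmax 1, mostly negative
    [0.10, 0.10, 0.20, 0.35, 0.25],  # argmax 3, mostly positive
]))
labels = np.array([4, 1, 0])
print(fine_grained_hits(logits, labels))  # [0. 1. 0.]
print(binary_hits(logits, labels))        # [1. 1. 0.]
```

The first row shows why binary accuracy runs much higher than fine-grained accuracy in the results below: predicting 3 for a true 4 is a fine-grained miss but a binary hit.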

The td.Metric block provides a mechanism for accumulating results across sequential and recursive computations without having to thread them through explicitly as return values. Metrics are wired up here inside of a td.Composition block, which allows us to explicitly specify the inputs of sub-blocks with calls to Block.reads() inside of a Composition.scope() context manager.

For training, we will sum the loss over all nodes. But for evals, we would like to separately calculate accuracies for the root (i.e. entire sentences) to match the numbers presented in the literature. We also need to distinguish between neutral and non-neutral sentiment labels, because binary sentiment doesn't get calculated for neutral nodes.

This is easy to do by putting our block creation code for calculating metrics inside of a function and passing it indicators. Note that this needs to be done in Python-land, because we can't inspect the contents of a tensor inside of Fold (since it hasn't been run yet).

def add_metrics(is_root, is_neutral):
  """A block that adds metrics for loss and hits; output is the LSTM state."""
  c = td.Composition(
      name='predict(is_root=%s, is_neutral=%s)' % (is_root, is_neutral))
  with c.scope():
    # destructure the input; (labels, (logits, state))
    labels = c.input[0]
    logits = td.GetItem(0).reads(c.input[1])
    state = td.GetItem(1).reads(c.input[1])

    # calculate loss
    loss = td.Function(tf_node_loss)
    td.Metric('all_loss').reads(loss.reads(logits, labels))
    if is_root: td.Metric('root_loss').reads(loss)

    # calculate fine-grained hits
    hits = td.Function(tf_fine_grained_hits)
    td.Metric('all_hits').reads(hits.reads(logits, labels))
    if is_root: td.Metric('root_hits').reads(hits)

    # calculate binary hits, if the label is not neutral
    if not is_neutral:
      binary_hits = td.Function(tf_binary_hits).reads(logits, labels)
      td.Metric('all_binary_hits').reads(binary_hits)
      if is_root: td.Metric('root_binary_hits').reads(binary_hits)

    # output the state, which will be read by our parent's LSTM cell
    c.output.reads(state)
  return c
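Conceptually, each td.Metric collects one value per node that reaches it, and the per-metric means are then taken over everything collected. The sketch below imitates that bookkeeping with plain Python lists to show which nodes feed which metric; `record_metrics` and the per-node numbers are hypothetical, and the real accumulation happens inside the compiled Fold graph.

```python
from collections import defaultdict

def record_metrics(metrics, is_root, is_neutral, loss, hit, binary_hit):
    """Appends one node's values under the metric names used in add_metrics()."""
    metrics['all_loss'].append(loss)
    metrics['all_hits'].append(hit)
    if is_root:
        metrics['root_loss'].append(loss)
        metrics['root_hits'].append(hit)
    # Binary accuracy is undefined for neutral nodes, so binary_hit is ignored.
    if not is_neutral:
        metrics['all_binary_hits'].append(binary_hit)
        if is_root:
            metrics['root_binary_hits'].append(binary_hit)

metrics = defaultdict(list)
# Made-up values for the three nodes of (4 (2 Spiderman) (3 ROCKS)):
# (is_root, label, loss, hit, binary_hit)
nodes = [(False, 2, 0.5, 1.0, 0.0),
         (False, 3, 0.7, 0.0, 1.0),
         (True, 4, 1.2, 1.0, 1.0)]
for is_root, label, loss, hit, binary_hit in nodes:
    record_metrics(metrics, is_root, label == 2, loss, hit, binary_hit)

# Per-metric means, analogous to tf.reduce_mean over each metric tensor.
means = {k: sum(v) / len(v) for k, v in metrics.items()}
print(sorted((k, len(v)) for k, v in metrics.items()))
# [('all_binary_hits', 2), ('all_hits', 3), ('all_loss', 3),
#  ('root_binary_hits', 1), ('root_hits', 1), ('root_loss', 1)]
```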

Use NLTK to define a tokenize function to split S-exprs into left and right parts. We need this to run our logits_and_state() block since it expects to be passed a list of tokens and our raw input is strings.

def tokenize(s):
  label, phrase = s[1:-1].split(None, 1)
  return label, sexpr.sexpr_tokenize(phrase)

Try it out.

tokenize('(X Y)')
('X', ['Y'])
tokenize('(X Y Z)')
('X', ['Y Z'])

Embed trees (represented as strings) by tokenizing and piping (>>) to logits_and_state, distinguishing between neutral and non-neutral labels. We don't know here whether or not we are the root node (since this is a recursive computation), so that gets threaded through as an indicator.

def embed_tree(logits_and_state, is_root):
  """Creates a block that embeds trees; output is tree LSTM state."""
  return td.InputTransform(tokenize) >> td.OneOf(
      key_fn=lambda pair: pair[0] == '2',  # label 2 means neutral
      case_blocks=(add_metrics(is_root, is_neutral=False),
                   add_metrics(is_root, is_neutral=True)),
      pre_block=(td.Scalar('int32'), logits_and_state))

Put everything together and create our top-level (i.e. root) model. It is rather simple.

model = embed_tree(logits_and_state(), is_root=True)

Resolve the forward declaration for embedding subtrees (the non-root case) with a second call to embed_tree.

embed_subtree.resolve_to(embed_tree(logits_and_state(), is_root=False))

Compile the model.

compiler = td.Compiler.create(model)
print('input type: %s' % model.input_type)
print('output type: %s' % model.output_type)
input type: PyObjectType()
output type: TupleType(TensorType((300,), 'float32'), TensorType((300,), 'float32'))

Setup for training

Calculate the mean of each raw metric.

metrics = {k: tf.reduce_mean(v) for k, v in compiler.metric_tensors.items()}

Magic numbers.

LEARNING_RATE = 0.05
KEEP_PROB = 0.75
BATCH_SIZE = 100
EPOCHS = 20
EMBEDDING_LEARNING_RATE_FACTOR = 0.1

Training with Adagrad.

train_feed_dict = {keep_prob_ph: KEEP_PROB}
loss = tf.reduce_sum(compiler.metric_tensors['all_loss'])
opt = tf.train.AdagradOptimizer(LEARNING_RATE)

Important detail from section 5.3 of Tai et al. (http://arxiv.org/pdf/1503.00075.pdf): downscale the gradients for the word embedding vectors 10x, otherwise we overfit horribly.

grads_and_vars = opt.compute_gradients(loss)
found = 0
for i, (grad, var) in enumerate(grads_and_vars):
  if var == word_embedding.weights:
    found += 1
    grad = tf.scalar_mul(EMBEDDING_LEARNING_RATE_FACTOR, grad)
    grads_and_vars[i] = (grad, var)
assert found == 1  # internal consistency check
train = opt.apply_gradients(grads_and_vars)
saver = tf.train.Saver()
/usr/local/google/home/madscience/nuke/v3/local/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.py:91: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "

The TF graph is now complete; initialize the variables.

sess.run(tf.global_variables_initializer())

Train the model

Start by defining a function that does a single step of training on a batch and returns the loss.

def train_step(batch):
  train_feed_dict[compiler.loom_input_tensor] = batch
  _, batch_loss = sess.run([train, loss], train_feed_dict)
  return batch_loss

Now similarly for an entire epoch of training.

def train_epoch(train_set):
  return sum(train_step(batch) for batch in td.group_by_batches(train_set, BATCH_SIZE))
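For intuition about what the batching and per-epoch shuffling do, here is a plain-Python sketch. `group_by_batches_sketch` and `epochs_sketch` are hypothetical reimplementations, not Fold's code: this version keeps the final short batch and shuffles every epoch, so check the Fold documentation for exactly how `td.group_by_batches` and `td.epochs` handle those details.

```python
import itertools
import random

def group_by_batches_sketch(iterable, batch_size):
    """Yields lists of up to batch_size items; the last batch may be short."""
    it = iter(iterable)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch

def epochs_sketch(items, n, seed=0):
    """Yields n reshuffled copies of items, one per epoch."""
    rng = random.Random(seed)
    items = list(items)  # materialize once, like memoized loom inputs
    for _ in range(n):
        rng.shuffle(items)
        yield list(items)

# 8544 training trees with BATCH_SIZE = 100: 85 full batches plus one of 44.
sizes = [len(b) for b in group_by_batches_sketch(range(8544), 100)]
print(len(sizes), sizes[-1])  # 86 44
```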

Use Compiler.build_loom_inputs() to transform train_trees into individual loom inputs (i.e. wiring diagrams) that we can use to actually run the model.

train_set = compiler.build_loom_inputs(train_trees)

Use Compiler.build_feed_dict() to build a feed dictionary for validation on the dev set. This is marginally faster and more convenient than calling build_loom_inputs. We used build_loom_inputs on the train set so that we can shuffle the individual wiring diagrams into different batches for each epoch.

dev_feed_dict = compiler.build_feed_dict(dev_trees)

Define a function to do an eval on the dev set and pretty-print some stats, returning accuracy on the dev set.

def dev_eval(epoch, train_loss):
  dev_metrics = sess.run(metrics, dev_feed_dict)
  dev_loss = dev_metrics['all_loss']
  dev_accuracy = ['%s: %.2f' % (k, v * 100) for k, v in
                  sorted(dev_metrics.items()) if k.endswith('hits')]
  print('epoch:%4d, train_loss: %.3e, dev_loss_avg: %.3e, dev_accuracy:\n  [%s]'
        % (epoch, train_loss, dev_loss, ' '.join(dev_accuracy)))
  return dev_metrics['root_hits']

Run the main training loop, saving the model after each epoch if it has the best accuracy on the dev set. Use the td.epochs utility function to memoize the loom inputs and shuffle them after every epoch of training.

best_accuracy = 0.0
save_path = os.path.join(data_dir, 'sentiment_model')
for epoch, shuffled in enumerate(td.epochs(train_set, EPOCHS), 1):
  train_loss = train_epoch(shuffled)
  accuracy = dev_eval(epoch, train_loss)
  if accuracy > best_accuracy:
    best_accuracy = accuracy
    checkpoint_path = saver.save(sess, save_path, global_step=epoch)
    print('model saved in file: %s' % checkpoint_path)
epoch:   1, train_loss: 2.262e+05, dev_loss_avg: 5.253e-01, dev_accuracy:
  [all_binary_hits: 88.94 all_hits: 78.30 root_binary_hits: 82.00 root_hits: 40.51]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-1
epoch:   2, train_loss: 1.590e+05, dev_loss_avg: 4.602e-01, dev_accuracy:
  [all_binary_hits: 90.41 all_hits: 81.00 root_binary_hits: 83.60 root_hits: 46.59]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-2
epoch:   3, train_loss: 1.443e+05, dev_loss_avg: 4.371e-01, dev_accuracy:
  [all_binary_hits: 91.17 all_hits: 82.02 root_binary_hits: 85.21 root_hits: 48.68]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-3
epoch:   4, train_loss: 1.357e+05, dev_loss_avg: 4.242e-01, dev_accuracy:
  [all_binary_hits: 91.63 all_hits: 82.45 root_binary_hits: 87.04 root_hits: 49.86]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-4
epoch:   5, train_loss: 1.297e+05, dev_loss_avg: 4.190e-01, dev_accuracy:
  [all_binary_hits: 92.07 all_hits: 82.64 root_binary_hits: 88.19 root_hits: 51.50]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-5
epoch:   6, train_loss: 1.246e+05, dev_loss_avg: 4.175e-01, dev_accuracy:
  [all_binary_hits: 91.77 all_hits: 82.52 root_binary_hits: 86.81 root_hits: 49.41]
epoch:   7, train_loss: 1.209e+05, dev_loss_avg: 4.164e-01, dev_accuracy:
  [all_binary_hits: 92.08 all_hits: 82.81 root_binary_hits: 87.61 root_hits: 50.41]
epoch:   8, train_loss: 1.172e+05, dev_loss_avg: 4.177e-01, dev_accuracy:
  [all_binary_hits: 91.92 all_hits: 82.88 root_binary_hits: 87.50 root_hits: 50.14]
epoch:   9, train_loss: 1.143e+05, dev_loss_avg: 4.158e-01, dev_accuracy:
  [all_binary_hits: 92.16 all_hits: 82.84 root_binary_hits: 87.73 root_hits: 49.86]
epoch:  10, train_loss: 1.120e+05, dev_loss_avg: 4.152e-01, dev_accuracy:
  [all_binary_hits: 92.27 all_hits: 82.91 root_binary_hits: 87.50 root_hits: 50.77]
epoch:  11, train_loss: 1.094e+05, dev_loss_avg: 4.179e-01, dev_accuracy:
  [all_binary_hits: 92.35 all_hits: 82.98 root_binary_hits: 88.76 root_hits: 50.14]
epoch:  12, train_loss: 1.074e+05, dev_loss_avg: 4.221e-01, dev_accuracy:
  [all_binary_hits: 91.96 all_hits: 83.03 root_binary_hits: 87.16 root_hits: 50.05]
epoch:  13, train_loss: 1.055e+05, dev_loss_avg: 4.224e-01, dev_accuracy:
  [all_binary_hits: 92.04 all_hits: 83.05 root_binary_hits: 87.50 root_hits: 50.05]
epoch:  14, train_loss: 1.039e+05, dev_loss_avg: 4.204e-01, dev_accuracy:
  [all_binary_hits: 92.38 all_hits: 83.01 root_binary_hits: 88.76 root_hits: 51.32]
epoch:  15, train_loss: 1.017e+05, dev_loss_avg: 4.229e-01, dev_accuracy:
  [all_binary_hits: 92.52 all_hits: 82.92 root_binary_hits: 88.53 root_hits: 49.68]
epoch:  16, train_loss: 1.004e+05, dev_loss_avg: 4.278e-01, dev_accuracy:
  [all_binary_hits: 92.57 all_hits: 83.00 root_binary_hits: 88.42 root_hits: 52.13]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-16
epoch:  17, train_loss: 9.887e+04, dev_loss_avg: 4.316e-01, dev_accuracy:
  [all_binary_hits: 92.31 all_hits: 82.87 root_binary_hits: 87.73 root_hits: 51.04]
epoch:  18, train_loss: 9.742e+04, dev_loss_avg: 4.328e-01, dev_accuracy:
  [all_binary_hits: 92.28 all_hits: 82.90 root_binary_hits: 88.42 root_hits: 51.59]
epoch:  19, train_loss: 9.633e+04, dev_loss_avg: 4.338e-01, dev_accuracy:
  [all_binary_hits: 92.41 all_hits: 82.86 root_binary_hits: 88.53 root_hits: 51.68]
epoch:  20, train_loss: 9.474e+04, dev_loss_avg: 4.368e-01, dev_accuracy:
  [all_binary_hits: 92.23 all_hits: 82.90 root_binary_hits: 87.96 root_hits: 50.14]

The model starts to overfit pretty quickly even with dropout, as the LSTM begins to memorize the training set (which is rather small).
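Because the loop saves only when dev root accuracy strictly improves, the last checkpoint written is also the best one. Replaying the logged `root_hits` values (copied from the training output above, as percentages) through the same rule reproduces the saved epochs; `checkpoint_epochs` is a hypothetical helper for this check.

```python
# Dev-set root_hits per epoch, copied from the training log above (percent).
root_hits = {1: 40.51, 2: 46.59, 3: 48.68, 4: 49.86, 5: 51.50,
             6: 49.41, 7: 50.41, 8: 50.14, 9: 49.86, 10: 50.77,
             11: 50.14, 12: 50.05, 13: 50.05, 14: 51.32, 15: 49.68,
             16: 52.13, 17: 51.04, 18: 51.59, 19: 51.68, 20: 50.14}

def checkpoint_epochs(scores):
    """Epochs at which the loop's 'accuracy > best_accuracy' test fires."""
    best, saved = 0.0, []
    for epoch in sorted(scores):
        if scores[epoch] > best:
            best = scores[epoch]
            saved.append(epoch)
    return saved

print(checkpoint_epochs(root_hits))  # [1, 2, 3, 4, 5, 16]
```

This matches the "model saved" lines in the log, so `checkpoint_path` ends up pointing at epoch 16.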

Evaluate the model

Restore the model from the last checkpoint, where we saw the best accuracy on the dev set.

saver.restore(sess, checkpoint_path)

See how we did.

test_results = sorted(sess.run(metrics, compiler.build_feed_dict(test_trees)).items())
print('    loss: [%s]' % ' '.join(
  '%s: %.3e' % (name.rsplit('_', 1)[0], v)
  for name, v in test_results if name.endswith('_loss')))
print('accuracy: [%s]' % ' '.join(
  '%s: %.2f' % (name.rsplit('_', 1)[0], v * 100)
  for name, v in test_results if name.endswith('_hits')))
    loss: [all: 4.276e-01 root: 1.121e+00]
accuracy: [all_binary: 92.37 all: 83.13 root_binary: 89.29 root: 51.90]

Not bad! See section 3.5.1 of our paper for discussion and a comparison of these results to the state of the art.
