The Stanford Sentiment Treebank is a corpus of ~10K one-sentence movie reviews from Rotten Tomatoes. The sentences have been parsed into binary trees with words at the leaves; every sub-tree has a label ranging from 0 (highly negative) to 4 (highly positive); 2 means neutral.
For example, (4 (2 Spiderman) (3 ROCKS)) is a sentence with two words, corresponding to a binary tree with three nodes. The label at the root, for the entire sentence, is 4 (highly positive). The label for the left child, a leaf corresponding to the word Spiderman, is 2 (neutral). The label for the right child, a leaf corresponding to the word ROCKS, is 3 (moderately positive).
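Purely as an illustration of the data format (this nested-tuple form is hypothetical and is not used by the pipeline below, which keeps trees as raw s-expression strings), the example sentence corresponds to the following labeled binary tree:
# Illustrative only: the example tree as nested (label, children) pairs,
# where children is either a single word (a leaf) or a list of two subtrees.
example_tree = (4, [(2, 'Spiderman'), (3, 'ROCKS')])
root_label, children = example_tree
assert root_label == 4                  # whole sentence: highly positive
assert children[0] == (2, 'Spiderman')  # left leaf: neutral
assert children[1] == (3, 'ROCKS')      # right leaf: moderately positive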
This notebook shows how to use TensorFlow Fold to train a model on the treebank using binary TreeLSTMs and GloVe word embedding vectors, as described in the paper Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks by Tai et al. The original Torch source code for the model, provided by the authors, is available here.
The model illustrates three of the more advanced features of Fold, namely: compositions (td.Composition) for wiring sub-blocks together, forward declarations (td.ForwardDeclaration) for building recursive blocks, and metrics (td.Metric) for accumulating results whose number varies with the structure of the input.
# boilerplate
import codecs
import functools
import os
import tempfile
import zipfile
from nltk.tokenize import sexpr
import numpy as np
from six.moves import urllib
import tensorflow as tf
sess = tf.InteractiveSession()
import tensorflow_fold as td
data_dir = tempfile.mkdtemp()
print('saving files to %s' % data_dir)
saving files to /tmp/tmpPhKqpj
def download_and_unzip(url_base, zip_name, *file_names):
  zip_path = os.path.join(data_dir, zip_name)
  url = url_base + zip_name
  print('downloading %s to %s' % (url, zip_path))
  urllib.request.urlretrieve(url, zip_path)
  out_paths = []
  with zipfile.ZipFile(zip_path, 'r') as f:
    for file_name in file_names:
      print('extracting %s' % file_name)
      out_paths.append(f.extract(file_name, path=data_dir))
  return out_paths
full_glove_path, = download_and_unzip(
'http://nlp.stanford.edu/data/', 'glove.840B.300d.zip',
'glove.840B.300d.txt')
downloading http://nlp.stanford.edu/data/glove.840B.300d.zip to /tmp/tmpPhKqpj/glove.840B.300d.zip extracting glove.840B.300d.txt
train_path, dev_path, test_path = download_and_unzip(
'http://nlp.stanford.edu/sentiment/', 'trainDevTestTrees_PTB.zip',
'trees/train.txt', 'trees/dev.txt', 'trees/test.txt')
downloading http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip to /tmp/tmpPhKqpj/trainDevTestTrees_PTB.zip extracting trees/train.txt extracting trees/dev.txt extracting trees/test.txt
Filter the GloVe embeddings down to the words that actually appear in the dataset, since the full embedding file is rather large (~5GB). This is purely a performance optimization and has no effect on the final results.
filtered_glove_path = os.path.join(data_dir, 'filtered_glove.txt')
def filter_glove():
  vocab = set()
  # Download the full set of unlabeled sentences separated by '|'.
  sentence_path, = download_and_unzip(
      'http://nlp.stanford.edu/~socherr/', 'stanfordSentimentTreebank.zip',
      'stanfordSentimentTreebank/SOStr.txt')
  with codecs.open(sentence_path, encoding='utf-8') as f:
    for line in f:
      # Drop the trailing newline and strip backslashes. Split into words.
      vocab.update(line.strip().replace('\\', '').split('|'))
  nread = 0
  nwrote = 0
  with codecs.open(full_glove_path, encoding='utf-8') as f:
    with codecs.open(filtered_glove_path, 'w', encoding='utf-8') as out:
      for line in f:
        nread += 1
        line = line.strip()
        if not line: continue
        if line.split(u' ', 1)[0] in vocab:
          out.write(line + '\n')
          nwrote += 1
  print('read %s lines, wrote %s' % (nread, nwrote))
filter_glove()
filter_glove()
downloading http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip to /tmp/tmpPhKqpj/stanfordSentimentTreebank.zip extracting stanfordSentimentTreebank/SOStr.txt read 2196018 lines, wrote 20725
Load the filtered word embeddings into a matrix and build a dict from words to indices into the matrix. Add a random embedding vector for out-of-vocabulary words.
def load_embeddings(embedding_path):
  """Loads embeddings, returns weight matrix and dict from words to indices."""
  print('loading word embeddings from %s' % embedding_path)
  weight_vectors = []
  word_idx = {}
  with codecs.open(embedding_path, encoding='utf-8') as f:
    for line in f:
      word, vec = line.split(u' ', 1)
      word_idx[word] = len(weight_vectors)
      weight_vectors.append(np.array(vec.split(), dtype=np.float32))
  # Annoying implementation detail; '(' and ')' are replaced by '-LRB-' and
  # '-RRB-' respectively in the parse-trees.
  word_idx[u'-LRB-'] = word_idx.pop(u'(')
  word_idx[u'-RRB-'] = word_idx.pop(u')')
  # Random embedding vector for unknown words.
  weight_vectors.append(np.random.uniform(
      -0.05, 0.05, weight_vectors[0].shape).astype(np.float32))
  return np.stack(weight_vectors), word_idx
weight_matrix, word_idx = load_embeddings(filtered_glove_path)
loading word embeddings from /tmp/tmpPhKqpj/filtered_glove.txt
Finally, load the treebank data.
def load_trees(filename):
  with codecs.open(filename, encoding='utf-8') as f:
    # Drop the trailing newline and strip backslashes.
    trees = [line.strip().replace('\\', '') for line in f]
    print('loaded %s trees from %s' % (len(trees), filename))
    return trees
train_trees = load_trees(train_path)
dev_trees = load_trees(dev_path)
test_trees = load_trees(test_path)
loaded 8544 trees from /tmp/tmpPhKqpj/trees/train.txt loaded 1101 trees from /tmp/tmpPhKqpj/trees/dev.txt loaded 2210 trees from /tmp/tmpPhKqpj/trees/test.txt
We want to compute a hidden state vector h for every node in the tree. The hidden state is the input to a linear layer with softmax output for predicting the sentiment label.
At the leaves of the tree, words are mapped to word-embedding vectors which serve as the input to a binary tree-LSTM with 0 for the previous states. At the internal nodes, the LSTM takes 0 as input, and previous states from its two children. More formally,
$h_{\text{word}} = \text{TreeLSTM}(\text{Embedding}(\text{word}), 0, 0)$
$h_{\text{left,right}} = \text{TreeLSTM}(0, h_{\text{left}}, h_{\text{right}})$
where $\text{TreeLSTM}(x, h_{\text{left}}, h_{\text{right}})$ is a special kind of LSTM cell that takes two hidden states as inputs, and has a separate forget gate for each of them. Specifically, it is Tai et al. eqs. 9-14 with N=2. One modification here from Tai et al. is that instead of L2 weight regularization, we use recurrent dropout as described in the paper Recurrent Dropout without Memory Loss.
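For reference, a rough transcription of those equations with N=2 (left child L, right child R; the recurrent-dropout mask on the candidate input is omitted for brevity) is:

$$\begin{aligned}
i &= \sigma(W^{(i)} x + U^{(i)}_L h_L + U^{(i)}_R h_R + b^{(i)}) \\
f_L &= \sigma(W^{(f)} x + U^{(f)}_{LL} h_L + U^{(f)}_{LR} h_R + b^{(f)}) \\
f_R &= \sigma(W^{(f)} x + U^{(f)}_{RL} h_L + U^{(f)}_{RR} h_R + b^{(f)}) \\
o &= \sigma(W^{(o)} x + U^{(o)}_L h_L + U^{(o)}_R h_R + b^{(o)}) \\
u &= \tanh(W^{(u)} x + U^{(u)}_L h_L + U^{(u)}_R h_R + b^{(u)}) \\
c &= i \odot u + f_L \odot c_L + f_R \odot c_R \\
h &= o \odot \tanh(c)
\end{aligned}$$

In the cell below, all of these affine maps are fused into a single linear layer applied to the concatenation of $x$, $h_L$ and $h_R$ (split five ways into $i$, $u$, $f_L$, $f_R$, $o$), a forget bias is added inside the forget-gate sigmoids, and dropout is applied to $u$ when the keep probability is below 1.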
We can implement TreeLSTM by subclassing the TensorFlow BasicLSTMCell.
class BinaryTreeLSTMCell(tf.contrib.rnn.BasicLSTMCell):
  """LSTM with two state inputs.

  This is the model described in section 3.2 of 'Improved Semantic
  Representations From Tree-Structured Long Short-Term Memory
  Networks' <http://arxiv.org/pdf/1503.00075.pdf>, with recurrent
  dropout as described in 'Recurrent Dropout without Memory Loss'
  <http://arxiv.org/pdf/1603.05118.pdf>.
  """

  def __init__(self, num_units, keep_prob=1.0):
    """Initialize the cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      keep_prob: Keep probability for recurrent dropout.
    """
    super(BinaryTreeLSTMCell, self).__init__(num_units)
    self._keep_prob = keep_prob

  def __call__(self, inputs, state, scope=None):
    with tf.variable_scope(scope or type(self).__name__):
      lhs, rhs = state
      c0, h0 = lhs
      c1, h1 = rhs
      concat = tf.contrib.layers.linear(
          tf.concat([inputs, h0, h1], 1), 5 * self._num_units)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f0, f1, o = tf.split(value=concat, num_or_size_splits=5, axis=1)

      j = self._activation(j)
      if not isinstance(self._keep_prob, float) or self._keep_prob < 1:
        j = tf.nn.dropout(j, self._keep_prob)

      new_c = (c0 * tf.sigmoid(f0 + self._forget_bias) +
               c1 * tf.sigmoid(f1 + self._forget_bias) +
               tf.sigmoid(i) * j)
      new_h = self._activation(new_c) * tf.sigmoid(o)

      new_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
      return new_h, new_state
Use a placeholder for the dropout keep probability, with a default of 1 (for eval).
keep_prob_ph = tf.placeholder_with_default(1.0, [])
Create the LSTM cell for our model. In addition to recurrent dropout, apply dropout to inputs and outputs, using TF's built-in dropout wrapper. Put the LSTM cell inside of a td.ScopedLayer
in order to manage variable scoping. This ensures that our LSTM's variables are encapsulated from the rest of the graph and get created exactly once.
lstm_num_units = 300 # Tai et al. used 150, but our regularization strategy is more effective
tree_lstm = td.ScopedLayer(
tf.contrib.rnn.DropoutWrapper(
BinaryTreeLSTMCell(lstm_num_units, keep_prob=keep_prob_ph),
input_keep_prob=keep_prob_ph, output_keep_prob=keep_prob_ph),
name_or_scope='tree_lstm')
Create the output layer using td.FC.
NUM_CLASSES = 5 # number of distinct sentiment labels
output_layer = td.FC(NUM_CLASSES, activation=None, name='output_layer')
Create the word embedding using td.Embedding. Note that the built-in Fold layers like Embedding and FC manage variable scoping automatically, so there is no need to put them inside scoped layers.
word_embedding = td.Embedding(
*weight_matrix.shape, initializer=weight_matrix, name='word_embedding')
We now have layers that encapsulate all of the trainable variables for our model. The next step is to create the Fold blocks that define how inputs (s-expressions encoded as strings) get processed and used to make predictions. Naturally this requires a recursive model, which we handle in Fold using a forward declaration. The recursive step is to take a subtree (represented as a string) and convert it into a hidden state vector (the LSTM state), thus embedding it in an n-dimensional space (where here n=300).
embed_subtree = td.ForwardDeclaration(name='embed_subtree')
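To see what the forward declaration buys us, here is a hedged plain-Python sketch of the recursion it lets us express declaratively. The helper names embed_subtree_sketch, embed_word and combine are hypothetical and purely illustrative; real Fold blocks are wired up below, not called like functions.
# Illustrative only: bottom-up embedding of a (label, children) tree,
# mirroring the recursive structure that embed_subtree will capture.
def embed_subtree_sketch(node, embed_word, combine):
  label, children = node          # label is used for the loss, not the embedding
  if isinstance(children, str):   # leaf: embed the word itself
    return embed_word(children)
  left, right = children          # internal node: embed both subtrees first,
  return combine(embed_subtree_sketch(left, embed_word, combine),   # then combine
                 embed_subtree_sketch(right, embed_word, combine))  # their states
In Fold the recursive call cannot be written directly, because embed_subtree has to be referenced before the block that implements it exists; the forward declaration stands in for it now and gets resolved to the real block further down.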
The core of the model is a block that takes as input a list of tokens. The tokens will be either:
[word] - a leaf with a single word, the base case for the recursion, or
[lhs, rhs] - an internal node consisting of a pair of sub-expressions.
The outputs of the block will be a pair consisting of logits (the prediction) and the LSTM state.
def logits_and_state():
  """Creates a block that goes from tokens to (logits, state) tuples."""
  unknown_idx = len(word_idx)
  lookup_word = lambda word: word_idx.get(word, unknown_idx)

  word2vec = (td.GetItem(0) >> td.InputTransform(lookup_word) >>
              td.Scalar('int32') >> word_embedding)

  pair2vec = (embed_subtree(), embed_subtree())

  # Trees are binary, so the tree layer takes two states as its input_state.
  zero_state = td.Zeros((tree_lstm.state_size,) * 2)
  # Input is a word vector.
  zero_inp = td.Zeros(word_embedding.output_type.shape[0])

  word_case = td.AllOf(word2vec, zero_state)
  pair_case = td.AllOf(zero_inp, pair2vec)

  tree2vec = td.OneOf(len, [(1, word_case), (2, pair_case)])

  return tree2vec >> tree_lstm >> (output_layer, td.Identity())
Note that we use the call operator ()
to create blocks that reference the embed_subtree
forward declaration, for the recursive case.
Define a per-node loss function for training.
def tf_node_loss(logits, labels):
  return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
Additionally calculate fine-grained and binary hits (i.e. un-normalized accuracy) for evals. Fine-grained accuracy is defined over all five class labels and will be calculated for all labels, whereas binary accuracy is defined over negative vs. positive classification and will not be calculated for neutral labels.
def tf_fine_grained_hits(logits, labels):
  predictions = tf.cast(tf.argmax(logits, 1), tf.int32)
  return tf.cast(tf.equal(predictions, labels), tf.float64)

def tf_binary_hits(logits, labels):
  softmax = tf.nn.softmax(logits)
  binary_predictions = (softmax[:, 3] + softmax[:, 4]) > (softmax[:, 0] + softmax[:, 1])
  binary_labels = labels > 2
  return tf.cast(tf.equal(binary_predictions, binary_labels), tf.float64)
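As a quick sanity check of the binary-hit logic (a NumPy-only sketch with made-up numbers, not part of the model), a node whose probability mass leans positive counts as a hit exactly when its gold label is above neutral:
# Illustrative only: replicate the tf_binary_hits decision for a single node.
probs = np.array([0.05, 0.15, 0.20, 0.35, 0.25])  # hypothetical softmax output
label = 4                                         # gold label: highly positive
binary_prediction = (probs[3] + probs[4]) > (probs[0] + probs[1])  # 0.60 > 0.20
binary_label = label > 2
print(binary_prediction == binary_label)  # True, i.e. a binary hit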
The td.Metric block provides a mechanism for accumulating results across sequential and recursive computations without having to thread them through explicitly as return values. Metrics are wired up here inside of a td.Composition block, which allows us to explicitly specify the inputs of sub-blocks with calls to Block.reads() inside of a Composition.scope() context manager.
For training, we will sum the loss over all nodes. But for evals, we would like to separately calculate accuracies for the root (i.e. entire sentences) to match the numbers presented in the literature. We also need to distinguish between neutral and non-neutral sentiment labels, because binary sentiment doesn't get calculated for neutral nodes.
This is easy to do by putting our block creation code for calculating metrics inside of a function and passing it indicators. Note that this needs to be done in Python-land, because we can't inspect the contents of a tensor inside of Fold (since it hasn't been run yet).
def add_metrics(is_root, is_neutral):
  """A block that adds metrics for loss and hits; output is the LSTM state."""
  c = td.Composition(
      name='predict(is_root=%s, is_neutral=%s)' % (is_root, is_neutral))
  with c.scope():
    # destructure the input; (labels, (logits, state))
    labels = c.input[0]
    logits = td.GetItem(0).reads(c.input[1])
    state = td.GetItem(1).reads(c.input[1])

    # calculate loss
    loss = td.Function(tf_node_loss)
    td.Metric('all_loss').reads(loss.reads(logits, labels))
    if is_root: td.Metric('root_loss').reads(loss)

    # calculate fine-grained hits
    hits = td.Function(tf_fine_grained_hits)
    td.Metric('all_hits').reads(hits.reads(logits, labels))
    if is_root: td.Metric('root_hits').reads(hits)

    # calculate binary hits, if the label is not neutral
    if not is_neutral:
      binary_hits = td.Function(tf_binary_hits).reads(logits, labels)
      td.Metric('all_binary_hits').reads(binary_hits)
      if is_root: td.Metric('root_binary_hits').reads(binary_hits)

    # output the state, which will be read by our parent's LSTM cell
    c.output.reads(state)
  return c
Use NLTK to define a tokenize
function to split S-exprs into left and right parts. We need this to run our logits_and_state()
block since it expects to be passed a list of tokens and our raw input is strings.
def tokenize(s):
  label, phrase = s[1:-1].split(None, 1)
  return label, sexpr.sexpr_tokenize(phrase)
Try it out.
tokenize('(X Y)')
('X', ['Y'])
tokenize('(X Y Z)')
('X', ['Y Z'])
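Applying it to the example sentence from the introduction, we would expect the root label and the two labeled sub-expressions back:
tokenize('(4 (2 Spiderman) (3 ROCKS))')
('4', ['(2 Spiderman)', '(3 ROCKS)'])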
Embed trees (represented as strings) by tokenizing and piping (>>) to logits_and_state(), distinguishing between neutral and non-neutral labels. We don't know here whether or not we are the root node (since this is a recursive computation), so that gets threaded through as an indicator.
def embed_tree(logits_and_state, is_root):
  """Creates a block that embeds trees; output is tree LSTM state."""
  return td.InputTransform(tokenize) >> td.OneOf(
      key_fn=lambda pair: pair[0] == '2',  # label 2 means neutral
      case_blocks=(add_metrics(is_root, is_neutral=False),
                   add_metrics(is_root, is_neutral=True)),
      pre_block=(td.Scalar('int32'), logits_and_state))
Put everything together and create our top-level (i.e. root) model. It is rather simple.
model = embed_tree(logits_and_state(), is_root=True)
Resolve the forward declaration for embedding subtrees (the non-root case) with a second call to embed_tree
.
embed_subtree.resolve_to(embed_tree(logits_and_state(), is_root=False))
Compile the model.
compiler = td.Compiler.create(model)
print('input type: %s' % model.input_type)
print('output type: %s' % model.output_type)
input type: PyObjectType() output type: TupleType(TensorType((300,), 'float32'), TensorType((300,), 'float32'))
metrics = {k: tf.reduce_mean(v) for k, v in compiler.metric_tensors.items()}
Magic numbers.
LEARNING_RATE = 0.05
KEEP_PROB = 0.75
BATCH_SIZE = 100
EPOCHS = 20
EMBEDDING_LEARNING_RATE_FACTOR = 0.1
Training with Adagrad.
train_feed_dict = {keep_prob_ph: KEEP_PROB}
loss = tf.reduce_sum(compiler.metric_tensors['all_loss'])
opt = tf.train.AdagradOptimizer(LEARNING_RATE)
Important detail from section 5.3 of Tai et al. (http://arxiv.org/pdf/1503.00075.pdf): downscale the gradients for the word embedding vectors 10x, otherwise we overfit horribly.
grads_and_vars = opt.compute_gradients(loss)
found = 0
for i, (grad, var) in enumerate(grads_and_vars):
  if var == word_embedding.weights:
    found += 1
    grad = tf.scalar_mul(EMBEDDING_LEARNING_RATE_FACTOR, grad)
    grads_and_vars[i] = (grad, var)
assert found == 1 # internal consistency check
train = opt.apply_gradients(grads_and_vars)
saver = tf.train.Saver()
/usr/local/google/home/madscience/nuke/v3/local/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.py:91: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory. "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
The TF graph is now complete; initialize the variables.
sess.run(tf.global_variables_initializer())
Start by defining a function that does a single step of training on a batch and returns the loss.
def train_step(batch):
  train_feed_dict[compiler.loom_input_tensor] = batch
  _, batch_loss = sess.run([train, loss], train_feed_dict)
  return batch_loss
Now similarly for an entire epoch of training.
def train_epoch(train_set):
  return sum(train_step(batch) for batch in td.group_by_batches(train_set, BATCH_SIZE))
Use Compiler.build_loom_inputs()
to transform train_trees
into individual loom inputs (i.e. wiring diagrams) that we can use to actually run the model.
train_set = compiler.build_loom_inputs(train_trees)
Use Compiler.build_feed_dict()
to build a feed dictionary for validation on the dev set. This is marginally faster and more convenient than calling build_loom_inputs
. We used build_loom_inputs
on the train set so that we can shuffle the individual wiring diagrams into different batches for each epoch.
dev_feed_dict = compiler.build_feed_dict(dev_trees)
Define a function to do an eval on the dev set and pretty-print some stats, returning accuracy on the dev set.
def dev_eval(epoch, train_loss):
  dev_metrics = sess.run(metrics, dev_feed_dict)
  dev_loss = dev_metrics['all_loss']
  dev_accuracy = ['%s: %.2f' % (k, v * 100) for k, v in
                  sorted(dev_metrics.items()) if k.endswith('hits')]
  print('epoch:%4d, train_loss: %.3e, dev_loss_avg: %.3e, dev_accuracy:\n [%s]'
        % (epoch, train_loss, dev_loss, ' '.join(dev_accuracy)))
  return dev_metrics['root_hits']
Run the main training loop, saving the model after each epoch if it has the best accuracy on the dev set. Use the td.epochs
utility function to memoize the loom inputs and shuffle them after every epoch of training.
best_accuracy = 0.0
save_path = os.path.join(data_dir, 'sentiment_model')
for epoch, shuffled in enumerate(td.epochs(train_set, EPOCHS), 1):
  train_loss = train_epoch(shuffled)
  accuracy = dev_eval(epoch, train_loss)
  if accuracy > best_accuracy:
    best_accuracy = accuracy
    checkpoint_path = saver.save(sess, save_path, global_step=epoch)
    print('model saved in file: %s' % checkpoint_path)
epoch: 1, train_loss: 2.262e+05, dev_loss_avg: 5.253e-01, dev_accuracy: [all_binary_hits: 88.94 all_hits: 78.30 root_binary_hits: 82.00 root_hits: 40.51]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-1
epoch: 2, train_loss: 1.590e+05, dev_loss_avg: 4.602e-01, dev_accuracy: [all_binary_hits: 90.41 all_hits: 81.00 root_binary_hits: 83.60 root_hits: 46.59]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-2
epoch: 3, train_loss: 1.443e+05, dev_loss_avg: 4.371e-01, dev_accuracy: [all_binary_hits: 91.17 all_hits: 82.02 root_binary_hits: 85.21 root_hits: 48.68]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-3
epoch: 4, train_loss: 1.357e+05, dev_loss_avg: 4.242e-01, dev_accuracy: [all_binary_hits: 91.63 all_hits: 82.45 root_binary_hits: 87.04 root_hits: 49.86]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-4
epoch: 5, train_loss: 1.297e+05, dev_loss_avg: 4.190e-01, dev_accuracy: [all_binary_hits: 92.07 all_hits: 82.64 root_binary_hits: 88.19 root_hits: 51.50]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-5
epoch: 6, train_loss: 1.246e+05, dev_loss_avg: 4.175e-01, dev_accuracy: [all_binary_hits: 91.77 all_hits: 82.52 root_binary_hits: 86.81 root_hits: 49.41]
epoch: 7, train_loss: 1.209e+05, dev_loss_avg: 4.164e-01, dev_accuracy: [all_binary_hits: 92.08 all_hits: 82.81 root_binary_hits: 87.61 root_hits: 50.41]
epoch: 8, train_loss: 1.172e+05, dev_loss_avg: 4.177e-01, dev_accuracy: [all_binary_hits: 91.92 all_hits: 82.88 root_binary_hits: 87.50 root_hits: 50.14]
epoch: 9, train_loss: 1.143e+05, dev_loss_avg: 4.158e-01, dev_accuracy: [all_binary_hits: 92.16 all_hits: 82.84 root_binary_hits: 87.73 root_hits: 49.86]
epoch: 10, train_loss: 1.120e+05, dev_loss_avg: 4.152e-01, dev_accuracy: [all_binary_hits: 92.27 all_hits: 82.91 root_binary_hits: 87.50 root_hits: 50.77]
epoch: 11, train_loss: 1.094e+05, dev_loss_avg: 4.179e-01, dev_accuracy: [all_binary_hits: 92.35 all_hits: 82.98 root_binary_hits: 88.76 root_hits: 50.14]
epoch: 12, train_loss: 1.074e+05, dev_loss_avg: 4.221e-01, dev_accuracy: [all_binary_hits: 91.96 all_hits: 83.03 root_binary_hits: 87.16 root_hits: 50.05]
epoch: 13, train_loss: 1.055e+05, dev_loss_avg: 4.224e-01, dev_accuracy: [all_binary_hits: 92.04 all_hits: 83.05 root_binary_hits: 87.50 root_hits: 50.05]
epoch: 14, train_loss: 1.039e+05, dev_loss_avg: 4.204e-01, dev_accuracy: [all_binary_hits: 92.38 all_hits: 83.01 root_binary_hits: 88.76 root_hits: 51.32]
epoch: 15, train_loss: 1.017e+05, dev_loss_avg: 4.229e-01, dev_accuracy: [all_binary_hits: 92.52 all_hits: 82.92 root_binary_hits: 88.53 root_hits: 49.68]
epoch: 16, train_loss: 1.004e+05, dev_loss_avg: 4.278e-01, dev_accuracy: [all_binary_hits: 92.57 all_hits: 83.00 root_binary_hits: 88.42 root_hits: 52.13]
model saved in file: /tmp/tmpPhKqpj/sentiment_model-16
epoch: 17, train_loss: 9.887e+04, dev_loss_avg: 4.316e-01, dev_accuracy: [all_binary_hits: 92.31 all_hits: 82.87 root_binary_hits: 87.73 root_hits: 51.04]
epoch: 18, train_loss: 9.742e+04, dev_loss_avg: 4.328e-01, dev_accuracy: [all_binary_hits: 92.28 all_hits: 82.90 root_binary_hits: 88.42 root_hits: 51.59]
epoch: 19, train_loss: 9.633e+04, dev_loss_avg: 4.338e-01, dev_accuracy: [all_binary_hits: 92.41 all_hits: 82.86 root_binary_hits: 88.53 root_hits: 51.68]
epoch: 20, train_loss: 9.474e+04, dev_loss_avg: 4.368e-01, dev_accuracy: [all_binary_hits: 92.23 all_hits: 82.90 root_binary_hits: 87.96 root_hits: 50.14]
The model starts to overfit pretty quickly even with dropout, as the LSTM begins to memorize the training set (which is rather small).
Restore the model from the last checkpoint, where we saw the best accuracy on the dev set.
saver.restore(sess, checkpoint_path)
See how we did.
test_results = sorted(sess.run(metrics, compiler.build_feed_dict(test_trees)).items())
print(' loss: [%s]' % ' '.join(
'%s: %.3e' % (name.rsplit('_', 1)[0], v)
for name, v in test_results if name.endswith('_loss')))
print('accuracy: [%s]' % ' '.join(
'%s: %.2f' % (name.rsplit('_', 1)[0], v * 100)
for name, v in test_results if name.endswith('_hits')))
loss: [all: 4.276e-01 root: 1.121e+00]
accuracy: [all_binary: 92.37 all: 83.13 root_binary: 89.29 root: 51.90]
Not bad! See section 3.5.1 of our paper for discussion and a comparison of these results to the state of the art.