Skip to content

Commit 941d404

Browse files
author
yyf
committed
Adding embedding visualization
1 parent de0504a commit 941d404

1 file changed

Lines changed: 69 additions & 0 deletions

File tree

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#! coding: utf-8
2+
3+
from __future__ import absolute_import
4+
from __future__ import division
5+
from __future__ import print_function
6+
7+
import re
8+
import os
9+
import io
10+
import numpy as np
11+
import tensorflow as tf
12+
from tensorflow.contrib.tensorboard.plugins import projector
13+
14+
15+
def word_embedding(LOG_DIR, emb, tsv):
    """Embedding visualization using TensorBoard.

    Writes a checkpoint containing the embedding variable, a metadata file
    with the per-row labels, and a projector config into LOG_DIR, so that
    `tensorboard --logdir=LOG_DIR` can display the embedding projector.

    Args:
        LOG_DIR: Directory where the checkpoint, metadata file and projector
            config are written. Point TensorBoard at this directory.
        emb: The embedding tensor with shape n x d, each row representing
            the vector of one word.
        tsv: Tab-separated values. If the metadata has only one column, the
            number of lines should equal n. Otherwise it should have n + 1
            lines, with the first line containing the header.
            If tsv is a string, it is the file name (relative to LOG_DIR) of
            an already-existing metadata file; otherwise it should be a list
            of label lines.
    """
    g = tf.Graph()
    with g.as_default():
        # The embedding must live in a tf.Variable so the projector can find
        # it by name inside the checkpoint saved below.
        emb = np.array(emb)
        n, d = emb.shape
        embedding_var = tf.Variable(emb, dtype=tf.float32, name='word_embedding')

        # Format: tensorflow/tensorboard/plugins/projector/projector_config.proto
        config = projector.ProjectorConfig()

        # You can add multiple embeddings. Here we add only one.
        embedding = config.embeddings.add()
        embedding.tensor_name = embedding_var.name

        # Link this tensor to its metadata file (e.g. labels).
        if isinstance(tsv, str):
            embedding.metadata_path = os.path.join(LOG_DIR, tsv)
        else:
            # Write the label lines ourselves. Explicit encoding: io.open in
            # text mode otherwise falls back to the locale default (and fixes
            # the original "_meatadata.tsv" filename typo).
            meta_path = os.path.join(LOG_DIR, "_metadata.tsv")
            with io.open(meta_path, "w", encoding="utf-8") as f:
                for line in tsv:
                    f.write(line + '\n')
            embedding.metadata_path = meta_path

        # Use the same LOG_DIR where you stored your checkpoint.
        summary_writer = tf.summary.FileWriter(LOG_DIR)

        # The next line writes a projector_config.pbtxt in the LOG_DIR.
        # TensorBoard will read this file during startup.
        projector.visualize_embeddings(summary_writer, config)

        with tf.Session(graph=g) as sess:
            tf.global_variables_initializer().run()
            tf.train.Saver().save(sess, os.path.join(LOG_DIR, "model.ckpt"),
                                  global_step=1)

        # Flush and close the writer so the projector config is actually on
        # disk when we return (the original leaked the writer).
        summary_writer.close()
58+
59+
60+
if __name__ == "__main__":
    # Demo: visualize 1000 random 200-dimensional word vectors.
    num_words, dim = 1000, 200
    with tf.Session() as sess:
        random_vectors = tf.random_normal([num_words, dim])
        labels = [str(i) for i in range(num_words)]
        word_embedding('/tmp/log', random_vectors.eval(), labels)

# To view the projector: tensorboard --logdir=/tmp/log
69+

0 commit comments

Comments
 (0)