|
| 1 | +#! coding: utf-8 |
| 2 | + |
| 3 | +from __future__ import absolute_import |
| 4 | +from __future__ import division |
| 5 | +from __future__ import print_function |
| 6 | + |
| 7 | +import re |
| 8 | +import os |
| 9 | +import io |
| 10 | +import numpy as np |
| 11 | +import tensorflow as tf |
| 12 | +from tensorflow.contrib.tensorboard.plugins import projector |
| 13 | + |
| 14 | + |
def word_embedding(LOG_DIR, emb, tsv):
    """Visualize a word-embedding matrix with the TensorBoard projector.

    Writes a checkpoint holding ``emb``, an (optional) metadata TSV file and a
    ``projector_config.pbtxt`` into ``LOG_DIR`` so that running
    ``tensorboard --logdir=LOG_DIR`` shows the embedding.

    Args:
        LOG_DIR: Directory receiving the checkpoint, metadata and projector
            config. Created if it does not already exist.
        emb: Array-like of shape (n, d); row i is the vector of word i.
        tsv: Tab-separated metadata. If a string, it is a TSV file path
            relative to LOG_DIR. Otherwise it is an iterable of lines that is
            written to ``LOG_DIR/_metadata.tsv`` — n lines for a single
            column, or n + 1 lines when the first line is a header.
    """
    # The checkpoint, metadata file and projector config all land in LOG_DIR,
    # so make sure it exists (os.makedirs(..., exist_ok=True) is Python-3
    # only; this form also works under the __future__-style Python 2 setup).
    if not os.path.isdir(LOG_DIR):
        os.makedirs(LOG_DIR)

    g = tf.Graph()
    with g.as_default():
        emb = np.asarray(emb)
        # A Variable is required so the embedding can be saved to a
        # checkpoint, which is how the projector plugin loads tensors.
        embedding_var = tf.Variable(emb, dtype=tf.float32, name='word_embedding')

        # Format: tensorflow/tensorboard/plugins/projector/projector_config.proto
        config = projector.ProjectorConfig()

        # You can add multiple embeddings; here we add only one.
        embedding = config.embeddings.add()
        embedding.tensor_name = embedding_var.name

        # Link this tensor to its metadata file (e.g. labels).
        if isinstance(tsv, str):
            embedding.metadata_path = os.path.join(LOG_DIR, tsv)
        else:
            # Fixed "_meatadata.tsv" typo; io.open needs an explicit encoding
            # to behave the same under Python 2 and 3.
            meta_path = os.path.join(LOG_DIR, "_metadata.tsv")
            with io.open(meta_path, "w", encoding="utf-8") as f:
                for line in tsv:
                    f.write(line + '\n')
            embedding.metadata_path = meta_path

        # Use the same LOG_DIR where the checkpoint is stored.
        summary_writer = tf.summary.FileWriter(LOG_DIR)

        # Writes projector_config.pbtxt into LOG_DIR; TensorBoard reads it
        # during startup.
        projector.visualize_embeddings(summary_writer, config)

        with tf.Session(graph=g) as sess:
            tf.global_variables_initializer().run()
            tf.train.Saver().save(sess, os.path.join(LOG_DIR, "model.ckpt"), global_step=1)
| 58 | + |
| 59 | + |
if __name__ == "__main__":

    # Smoke test: project 1000 random 200-dimensional vectors, each labelled
    # by its row index, into /tmp/log for inspection with TensorBoard.
    num_words, dim = 1000, 200
    labels = [str(i) for i in range(num_words)]
    with tf.Session() as sess:
        vectors = sess.run(tf.random_normal([num_words, dim]))
        word_embedding('/tmp/log', vectors, labels)
| 67 | + |
# tensorboard --logdir=/tmp/log
| 69 | + |
0 commit comments