|
| 1 | +using System; |
| 2 | +using System.Collections.Generic; |
| 3 | +using System.Diagnostics; |
| 4 | +using System.IO; |
| 5 | +using System.Linq; |
| 6 | +using System.Text; |
| 7 | +using Tensorflow; |
| 8 | +using TensorFlowNET.Examples.Utility; |
| 9 | +using static Tensorflow.Python; |
| 10 | + |
| 11 | +namespace TensorFlowNET.Examples.ImageProcess |
| 12 | +{ |
| 13 | + /// <summary> |
| 14 | + /// In this tutorial, we will reuse the feature extraction capabilities from powerful image classifiers trained on ImageNet |
| 15 | + /// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this |
| 16 | + /// by taking a piece of a model that has already been trained on a related task and reusing it in a new model. |
| 17 | + /// |
| 18 | + /// https://www.tensorflow.org/hub/tutorials/image_retraining |
| 19 | + /// </summary> |
| 20 | + public class RetrainImageClassifier : IExample |
| 21 | + { |
| 22 | + public int Priority => 16; |
| 23 | + |
| 24 | + public bool Enabled { get; set; } = false; |
| 25 | + public bool ImportGraph { get; set; } = true; |
| 26 | + |
| 27 | + public string Name => "Retrain Image Classifier"; |
| 28 | + |
| 29 | + const string data_dir = "retrain_images"; |
| 30 | + string summaries_dir = Path.Join(data_dir, "retrain_logs"); |
| 31 | + string image_dir = Path.Join(data_dir, "flower_photos"); |
| 32 | + string bottleneck_dir = Path.Join(data_dir, "bottleneck"); |
| 33 | + string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3"; |
| 34 | + float testing_percentage = 0.1f; |
| 35 | + float validation_percentage = 0.1f; |
| 36 | + Tensor resized_image_tensor; |
| 37 | + Dictionary<string, Dictionary<string, string[]>> image_lists; |
| 38 | + |
| 39 | + public bool Run() |
| 40 | + { |
| 41 | + PrepareData(); |
| 42 | + |
| 43 | + var graph = tf.Graph().as_default(); |
| 44 | + tf.train.import_meta_graph("graph/InceptionV3.meta"); |
| 45 | + Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze"); |
| 46 | + Tensor resized_image_tensor = graph.OperationByName("Placeholder"); |
| 47 | + |
| 48 | + var sw = new Stopwatch(); |
| 49 | + |
| 50 | + with(tf.Session(graph), sess => |
| 51 | + { |
| 52 | + // Initialize all weights: for the module to their pretrained values, |
| 53 | + // and for the newly added retraining layer to random initial values. |
| 54 | + var init = tf.global_variables_initializer(); |
| 55 | + sess.run(init); |
| 56 | + |
| 57 | + var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding(); |
| 58 | + |
| 59 | + // We'll make sure we've calculated the 'bottleneck' image summaries and |
| 60 | + // cached them on disk. |
| 61 | + cache_bottlenecks(sess, image_lists, image_dir, |
| 62 | + bottleneck_dir, jpeg_data_tensor, |
| 63 | + decoded_image_tensor, resized_image_tensor, |
| 64 | + bottleneck_tensor, tfhub_module); |
| 65 | + }); |
| 66 | + |
| 67 | + return false; |
| 68 | + } |
| 69 | + |
| 70 | + /// <summary> |
| 71 | + /// Ensures all the training, testing, and validation bottlenecks are cached. |
| 72 | + /// </summary> |
| 73 | + /// <param name="sess"></param> |
| 74 | + /// <param name="image_lists"></param> |
| 75 | + /// <param name="image_dir"></param> |
| 76 | + /// <param name="bottleneck_dir"></param> |
| 77 | + /// <param name="jpeg_data_tensor"></param> |
| 78 | + /// <param name="decoded_image_tensor"></param> |
| 79 | + /// <param name="resized_image_tensor"></param> |
| 80 | + /// <param name="bottleneck_tensor"></param> |
| 81 | + /// <param name="tfhub_module"></param> |
| 82 | + private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, |
| 83 | + string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor, Tensor decoded_image_tensor, |
| 84 | + Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name) |
| 85 | + { |
| 86 | + int how_many_bottlenecks = 0; |
| 87 | + foreach(var (label_name, label_lists) in image_lists) |
| 88 | + { |
| 89 | + foreach(var category in new string[] { "training", "testing", "validation" }) |
| 90 | + { |
| 91 | + var category_list = label_lists[category]; |
| 92 | + foreach(var (index, unused_base_name) in enumerate(category_list)) |
| 93 | + { |
| 94 | + get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category, |
| 95 | + bottleneck_dir, jpeg_data_tensor, decoded_image_tensor, |
| 96 | + resized_input_tensor, bottleneck_tensor, module_name); |
| 97 | + } |
| 98 | + } |
| 99 | + } |
| 100 | + } |
| 101 | + |
| 102 | + private void get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, |
| 103 | + string label_name, int index, string image_dir, string category, string bottleneck_dir, |
| 104 | + Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, |
| 105 | + Tensor bottleneck_tensor, string module_name) |
| 106 | + { |
| 107 | + var label_lists = image_lists[label_name]; |
| 108 | + var sub_dir_path = Path.Join(image_dir, label_name); |
| 109 | + } |
| 110 | + |
| 111 | + public void PrepareData() |
| 112 | + { |
| 113 | + // get a set of images to teach the network about the new classes |
| 114 | + string fileName = "flower_photos.tgz"; |
| 115 | + string url = $"http://download.tensorflow.org/models/{fileName}"; |
| 116 | + Web.Download(url, data_dir, fileName); |
| 117 | + Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir); |
| 118 | + |
| 119 | + // download graph meta data |
| 120 | + url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta"; |
| 121 | + Web.Download(url, "graph", "InceptionV3.meta"); |
| 122 | + |
| 123 | + // Prepare necessary directories that can be used during training |
| 124 | + Directory.CreateDirectory(summaries_dir); |
| 125 | + Directory.CreateDirectory(bottleneck_dir); |
| 126 | + |
| 127 | + // Look at the folder structure, and create lists of all the images. |
| 128 | + image_lists = create_image_lists(); |
| 129 | + var class_count = len(image_lists); |
| 130 | + if (class_count == 0) |
| 131 | + print($"No valid folders of images found at {image_dir}"); |
| 132 | + if (class_count == 1) |
| 133 | + print("Only one valid folder of images found at " + |
| 134 | + image_dir + |
| 135 | + " - multiple classes are needed for classification."); |
| 136 | + } |
| 137 | + |
| 138 | + private (Tensor, Tensor) add_jpeg_decoding() |
| 139 | + { |
| 140 | + // height, width, depth |
| 141 | + var input_dim = (299, 299, 3); |
| 142 | + var jpeg_data = tf.placeholder(tf.chars, name: "DecodeJPGInput"); |
| 143 | + var decoded_image = tf.image.decode_jpeg(jpeg_data, channels: input_dim.Item3); |
| 144 | + // Convert from full range of uint8 to range [0,1] of float32. |
| 145 | + var decoded_image_as_float = tf.image.convert_image_dtype(decoded_image, tf.float32); |
| 146 | + var decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0); |
| 147 | + var resize_shape = tf.stack(new int[] { input_dim.Item1, input_dim.Item2 }); |
| 148 | + var resize_shape_as_int = tf.cast(resize_shape, dtype: tf.int32); |
| 149 | + var resized_image = tf.image.resize_bilinear(decoded_image_4d, resize_shape_as_int); |
| 150 | + return (jpeg_data, resized_image); |
| 151 | + } |
| 152 | + |
| 153 | + /// <summary> |
| 154 | + /// Builds a list of training images from the file system. |
| 155 | + /// </summary> |
| 156 | + private Dictionary<string, Dictionary<string, string[]>> create_image_lists() |
| 157 | + { |
| 158 | + var sub_dirs = tf.gfile.Walk(image_dir) |
| 159 | + .Select(x => x.Item1) |
| 160 | + .OrderBy(x => x) |
| 161 | + .ToArray(); |
| 162 | + |
| 163 | + var result = new Dictionary<string, Dictionary<string, string[]>>(); |
| 164 | + |
| 165 | + foreach(var sub_dir in sub_dirs) |
| 166 | + { |
| 167 | + var dir_name = sub_dir.Split(Path.DirectorySeparatorChar).Last(); |
| 168 | + print($"Looking for images in '{dir_name}'"); |
| 169 | + var file_list = Directory.GetFiles(sub_dir); |
| 170 | + if (len(file_list) < 20) |
| 171 | + print($"WARNING: Folder has less than 20 images, which may cause issues."); |
| 172 | + |
| 173 | + var label_name = dir_name.ToLower(); |
| 174 | + result[label_name] = new Dictionary<string, string[]>(); |
| 175 | + int testing_count = (int)Math.Floor(file_list.Length * testing_percentage); |
| 176 | + int validation_count = (int)Math.Floor(file_list.Length * validation_percentage); |
| 177 | + result[label_name]["testing"] = file_list.Take(testing_count).ToArray(); |
| 178 | + result[label_name]["validation"] = file_list.Skip(testing_count).Take(validation_count).ToArray(); |
| 179 | + result[label_name]["training"] = file_list.Skip(testing_count + validation_count).ToArray(); |
| 180 | + } |
| 181 | + |
| 182 | + return result; |
| 183 | + } |
| 184 | + } |
| 185 | +} |
0 commit comments