Skip to content

Commit beda6c4

Browse files
committed
Add a Transfer Learning example for image recognition. SciSharp#248
1 parent ddbbe06 commit beda6c4

10 files changed

Lines changed: 301 additions & 5 deletions

File tree

graph/InceptionV3.meta

2.97 MB
Binary file not shown.

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ public static Tensor transpose<T1>(T1 a, int[] perm = null, string name = "trans
3434
/// <summary>
/// Forwards to <c>gen_array_ops.squeeze</c>.
/// </summary>
/// <param name="input">Tensor handed to the generated op.</param>
/// <param name="axis">Axis list handed to the generated op; defaults to null.</param>
/// <param name="name">Operation name handed to the generated op; defaults to null.</param>
/// <param name="squeeze_dims">Accepted for signature compatibility; it is not forwarded to the generated op.</param>
/// <returns>The tensor returned by <c>gen_array_ops.squeeze</c>.</returns>
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1)
{
    return gen_array_ops.squeeze(input, axis, name);
}
3636

37+
/// <summary>
/// Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
/// </summary>
/// <param name="values">Values to stack; handed to <c>array_ops.stack</c> unchanged.</param>
/// <param name="axis">Axis argument handed to <c>array_ops.stack</c>; defaults to 0.</param>
/// <param name="name">Operation name handed to <c>array_ops.stack</c>; defaults to "stack".</param>
/// <returns>The tensor returned by <c>array_ops.stack</c>.</returns>
public static Tensor stack(object values, int axis = 0, string name = "stack")
{
    return array_ops.stack(values, axis, name: name);
}
46+
3747
public static Tensor one_hot(Tensor indices, int depth,
3848
Tensor on_value = null,
3949
Tensor off_value = null,

src/TensorFlowNET.Core/APIs/tf.io.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.IO;
45

56
namespace Tensorflow
67
{
78
public static partial class tf
89
{
10+
public static GFile gfile = new GFile();
911
/// <summary>
/// Forwards to <c>gen_io_ops.read_file</c>.
/// </summary>
/// <param name="filename">File name handed to the generated op.</param>
/// <param name="name">Operation name handed to the generated op; defaults to null.</param>
/// <returns>The tensor returned by <c>gen_io_ops.read_file</c>.</returns>
public static Tensor read_file(string filename, string name = null)
{
    return gen_io_ops.read_file(filename, name);
}
1012

1113
/// <summary>
/// Entry point for the image operations; every access creates a new <c>gen_image_ops</c> instance.
/// </summary>
public static gen_image_ops image
{
    get { return new gen_image_ops(); }
}

src/TensorFlowNET.Core/IO/gfile.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.IO;
4+
using System.Text;
5+
6+
namespace Tensorflow.IO
7+
{
8+
/// <summary>
/// Directory-walking helper built on <see cref="Directory"/>.
/// </summary>
public class GFile
{
    /// <summary>
    /// Lazily walks the directory tree rooted at <paramref name="top"/>.
    /// </summary>
    /// <param name="top">a Directory name</param>
    /// <param name="in_order">Handed to the internal walker, which does not currently consult it.</param>
    /// <returns>
    /// One (directory, sub-directories, files) tuple for every directory in the tree that has
    /// no sub-directories; directories that do contain sub-directories are descended into but
    /// are not yielded themselves. The arrays hold the values returned by
    /// <see cref="Directory.GetDirectories(string)"/> and <see cref="Directory.GetFiles(string)"/>.
    /// </returns>
    public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true)
    {
        return walk_v2(top, in_order);
    }

    /// <summary>
    /// Depth-first walk that yields only directories without sub-directories.
    /// </summary>
    /// <param name="top">Directory the walk starts from.</param>
    /// <param name="topdown">Not consulted by the current implementation.</param>
    private IEnumerable<(string, string[], string[])> walk_v2(string top, bool topdown)
    {
        var pending = new Stack<string>();
        pending.Push(top);

        while (pending.Count > 0)
        {
            var current = pending.Pop();
            var children = Directory.GetDirectories(current);
            var entries = Directory.GetFiles(current);

            if (children.Length == 0)
            {
                yield return (current, children, entries);
                continue;
            }

            // Push in reverse so that children are visited in the order
            // Directory.GetDirectories reported them, each subtree completely
            // before the next sibling.
            for (var i = children.Length - 1; i >= 0; i--)
            {
                pending.Push(children[i]);
            }
        }
    }
}
35+
}

src/TensorFlowNET.Core/Operations/gen_image_ops.py.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,48 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using static Tensorflow.Python;
45

56
namespace Tensorflow
67
{
78
public class gen_image_ops
89
{
910
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
1011

12+
/// <summary>
/// Converts <paramref name="image"/> to <paramref name="dtype"/>.
/// Only two cases are implemented: identical dtypes (identity) and an integer
/// source converted to a non-integer target (cast, then scale by 1 / source max).
/// </summary>
/// <param name="image">Tensor to convert.</param>
/// <param name="dtype">Requested element type.</param>
/// <param name="saturate">Not consulted by the implemented conversion paths.</param>
/// <param name="name">Optional operation name; replaced by the name scope for the scaled result.</param>
/// <returns>The identity of <paramref name="image"/> or the cast-and-scaled tensor.</returns>
/// <exception cref="NotImplementedException">
/// Thrown for integer-to-integer, floating-to-floating and any conversion whose source is not an integer type.
/// </exception>
public Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool saturate = false, string name = null)
{
    if (dtype == image.dtype)
        return array_ops.identity(image, name: name);

    return with(ops.name_scope(name, "convert_image", image), scope =>
    {
        name = scope;

        if (image.dtype.is_integer() && dtype.is_integer())
            throw new NotImplementedException("convert_image_dtype is_integer");

        if (image.dtype.is_floating() && dtype.is_floating())
            throw new NotImplementedException("convert_image_dtype is_floating");

        // This branch used to reuse the integer-to-integer message, which made
        // the two failures indistinguishable; give it its own text.
        if (!image.dtype.is_integer())
            throw new NotImplementedException("convert_image_dtype non-integer source");

        // Converting to float: first cast, then scale. No saturation possible.
        var cast = math_ops.cast(image, dtype);
        var scale = 1.0f / image.dtype.max();
        return math_ops.multiply(cast, scale, name: name);
    });
}
45+
1146
public Tensor decode_jpeg(Tensor contents,
1247
int channels = 0,
1348
int ratio = 1,

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public static Tensor abs(Tensor x, string name = null)
2424
});
2525
}
2626

27-
public static Tensor add(Tensor x, Tensor y, string name = null)
27+
public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null)
2828
=> gen_math_ops.add(x, y, name);
2929

3030
/// <summary>
@@ -68,10 +68,10 @@ public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, s
6868
/// <summary>
/// Forwards both operands to <c>gen_math_ops.equal</c>.
/// </summary>
/// <param name="x">Left operand.</param>
/// <param name="y">Right operand.</param>
/// <param name="name">Operation name handed to the generated op; defaults to null.</param>
/// <returns>The tensor returned by <c>gen_math_ops.equal</c>.</returns>
public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)
{
    return gen_math_ops.equal(x, y, name: name);
}
7070

71-
public static Tensor multiply(Tensor x, Tensor y, string name = null)
71+
public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
7272
=> gen_math_ops.mul(x, y, name: name);
7373

74-
public static Tensor mul_no_nan(Tensor x, Tensor y, string name = null)
74+
public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null)
7575
=> gen_math_ops.mul_no_nan(x, y, name: name);
7676

7777
/// <summary>

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,26 @@ public static TF_DataType as_ref(this TF_DataType type)
123123
type;
124124
}
125125

126+
/// <summary>
/// Returns the largest value of an integer data type.
/// Covers the integer types whose maximum fits in an <see cref="int"/>.
/// </summary>
/// <param name="type">Data type to query.</param>
/// <returns>The maximum value of <paramref name="type"/>.</returns>
/// <exception cref="NotImplementedException">
/// Thrown for every other data type, including the 32-bit unsigned and 64-bit
/// integer types whose maximum does not fit the <see cref="int"/> return type.
/// </exception>
public static int max(this TF_DataType type)
{
    switch (type)
    {
        case TF_DataType.TF_INT8:
            return sbyte.MaxValue;
        case TF_DataType.TF_UINT8:
            return byte.MaxValue;
        case TF_DataType.TF_INT16:
            return short.MaxValue;
        case TF_DataType.TF_UINT16:
            return ushort.MaxValue;
        case TF_DataType.TF_INT32:
            return int.MaxValue;
        default:
            throw new NotImplementedException($"max {type.name()}");
    }
}
136+
126137
/// <summary>
/// Tells whether <paramref name="type"/> is one of the complex data types
/// (<c>TF_COMPLEX</c>, <c>TF_COMPLEX64</c> or <c>TF_COMPLEX128</c>).
/// </summary>
/// <param name="type">Data type to test.</param>
/// <returns>True when <paramref name="type"/> equals one of the three complex members.</returns>
public static bool is_complex(this TF_DataType type)
{
    if (type == TF_DataType.TF_COMPLEX)
    {
        return true;
    }

    if (type == TF_DataType.TF_COMPLEX64)
    {
        return true;
    }

    return type == TF_DataType.TF_COMPLEX128;
}
130141

131142
public static bool is_integer(this TF_DataType type)
132143
{
133-
return type == TF_DataType.TF_INT8 || type == TF_DataType.TF_INT16 || type == TF_DataType.TF_INT32 || type == TF_DataType.TF_INT64;
144+
return type == TF_DataType.TF_INT8 || type == TF_DataType.TF_INT16 || type == TF_DataType.TF_INT32 || type == TF_DataType.TF_INT64 ||
145+
type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64;
134146
}
135147

136148
public static bool is_floating(this TF_DataType type)

src/TensorFlowNET.Core/ops.GraphKeys.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public static class GraphKeys
3737

3838
public static string GLOBAL_STEP = GLOBAL_STEP = "global_step";
3939

40-
public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" };
40+
public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" };
4141
/// <summary>
4242
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
4343
/// </summary>
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.IO;
5+
using System.Linq;
6+
using System.Text;
7+
using Tensorflow;
8+
using TensorFlowNET.Examples.Utility;
9+
using static Tensorflow.Python;
10+
11+
namespace TensorFlowNET.Examples.ImageProcess
12+
{
13+
/// <summary>
14+
/// In this tutorial, we will reuse the feature extraction capabilities from powerful image classifiers trained on ImageNet
15+
/// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this
16+
/// by taking a piece of a model that has already been trained on a related task and reusing it in a new model.
17+
///
18+
/// https://www.tensorflow.org/hub/tutorials/image_retraining
19+
/// </summary>
20+
public class RetrainImageClassifier : IExample
{
    public int Priority => 16;

    public bool Enabled { get; set; } = false;
    public bool ImportGraph { get; set; } = true;

    public string Name => "Retrain Image Classifier";

    // Root folder for everything this example downloads or generates.
    const string data_dir = "retrain_images";
    string summaries_dir = Path.Join(data_dir, "retrain_logs");
    string image_dir = Path.Join(data_dir, "flower_photos");
    string bottleneck_dir = Path.Join(data_dir, "bottleneck");
    string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
    // Fraction of each label's files placed in the "testing" / "validation" lists.
    float testing_percentage = 0.1f;
    float validation_percentage = 0.1f;
    // Set by Run() from the imported graph's "Placeholder" operation.
    Tensor resized_image_tensor;
    // label -> category ("training" | "testing" | "validation") -> file paths; set by PrepareData().
    Dictionary<string, Dictionary<string, string[]>> image_lists;

    /// <summary>
    /// Prepares the data, imports the InceptionV3 meta graph and caches the bottlenecks.
    /// </summary>
    /// <returns>Always false in the current implementation.</returns>
    public bool Run()
    {
        PrepareData();

        var graph = tf.Graph().as_default();
        tf.train.import_meta_graph("graph/InceptionV3.meta");
        Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");
        // Store in the field: a same-named local used to shadow it and left the field unassigned.
        resized_image_tensor = graph.OperationByName("Placeholder");

        with(tf.Session(graph), sess =>
        {
            // Initialize all weights: for the module to their pretrained values,
            // and for the newly added retraining layer to random initial values.
            var init = tf.global_variables_initializer();
            sess.run(init);

            var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();

            // We'll make sure we've calculated the 'bottleneck' image summaries and
            // cached them on disk.
            cache_bottlenecks(sess, image_lists, image_dir,
                bottleneck_dir, jpeg_data_tensor,
                decoded_image_tensor, resized_image_tensor,
                bottleneck_tensor, tfhub_module);
        });

        return false;
    }

    /// <summary>
    /// Ensures all the training, testing, and validation bottlenecks are cached.
    /// Calls <see cref="get_or_create_bottleneck"/> once for every image of every
    /// label in the "training", "testing" and "validation" lists.
    /// </summary>
    /// <param name="sess">Session forwarded to <see cref="get_or_create_bottleneck"/>.</param>
    /// <param name="image_lists">label -> category -> file list.</param>
    /// <param name="image_dir">Forwarded unchanged.</param>
    /// <param name="bottleneck_dir">Forwarded unchanged.</param>
    /// <param name="jpeg_data_tensor">Forwarded unchanged.</param>
    /// <param name="decoded_image_tensor">Forwarded unchanged.</param>
    /// <param name="resized_input_tensor">Forwarded unchanged.</param>
    /// <param name="bottleneck_tensor">Forwarded unchanged.</param>
    /// <param name="module_name">Forwarded unchanged.</param>
    private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
        string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
        Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name)
    {
        foreach (var (label_name, label_lists) in image_lists)
        {
            foreach (var category in new string[] { "training", "testing", "validation" })
            {
                var category_list = label_lists[category];
                foreach (var (index, unused_base_name) in enumerate(category_list))
                {
                    get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category,
                        bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
                        resized_input_tensor, bottleneck_tensor, module_name);
                }
            }
        }
    }

    /// <summary>
    /// Unfinished: currently only looks up the label's lists and builds the label's
    /// image sub-directory path; neither value is used yet and nothing is cached.
    /// </summary>
    private void get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
        string label_name, int index, string image_dir, string category, string bottleneck_dir,
        Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
        Tensor bottleneck_tensor, string module_name)
    {
        var label_lists = image_lists[label_name];
        var sub_dir_path = Path.Join(image_dir, label_name);
    }

    /// <summary>
    /// Downloads and extracts the flower photos, downloads the graph meta file,
    /// creates the working directories and builds <see cref="image_lists"/>.
    /// Prints a message when zero or only one image folder was found.
    /// </summary>
    public void PrepareData()
    {
        // get a set of images to teach the network about the new classes
        string fileName = "flower_photos.tgz";
        string url = $"http://download.tensorflow.org/models/{fileName}";
        Web.Download(url, data_dir, fileName);
        Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir);

        // download graph meta data
        url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta";
        Web.Download(url, "graph", "InceptionV3.meta");

        // Prepare necessary directories that can be used during training
        Directory.CreateDirectory(summaries_dir);
        Directory.CreateDirectory(bottleneck_dir);

        // Look at the folder structure, and create lists of all the images.
        image_lists = create_image_lists();
        var class_count = len(image_lists);
        if (class_count == 0)
            print($"No valid folders of images found at {image_dir}");
        if (class_count == 1)
            print("Only one valid folder of images found at " +
                 image_dir +
                 " - multiple classes are needed for classification.");
    }

    /// <summary>
    /// Adds the operations that decode a JPEG, convert it to float32, add a leading
    /// dimension and resize it bilinearly to 299 x 299.
    /// </summary>
    /// <returns>The "DecodeJPGInput" placeholder and the resized image tensor.</returns>
    private (Tensor, Tensor) add_jpeg_decoding()
    {
        // height, width, depth
        var input_dim = (299, 299, 3);
        var jpeg_data = tf.placeholder(tf.chars, name: "DecodeJPGInput");
        var decoded_image = tf.image.decode_jpeg(jpeg_data, channels: input_dim.Item3);
        // Convert from full range of uint8 to range [0,1] of float32.
        var decoded_image_as_float = tf.image.convert_image_dtype(decoded_image, tf.float32);
        var decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0);
        var resize_shape = tf.stack(new int[] { input_dim.Item1, input_dim.Item2 });
        var resize_shape_as_int = tf.cast(resize_shape, dtype: tf.int32);
        var resized_image = tf.image.resize_bilinear(decoded_image_4d, resize_shape_as_int);
        return (jpeg_data, resized_image);
    }

    /// <summary>
    /// Builds a list of training images from the file system.
    /// Every directory reported by the walk below <see cref="image_dir"/> becomes a label
    /// (its lower-cased folder name); the first <see cref="testing_percentage"/> of its files
    /// go to "testing", the next <see cref="validation_percentage"/> to "validation" and the
    /// rest to "training".
    /// </summary>
    /// <returns>label -> category -> file paths.</returns>
    private Dictionary<string, Dictionary<string, string[]>> create_image_lists()
    {
        // The root image directory is not a class folder, so it is never treated as a label.
        var sub_dirs = tf.gfile.Walk(image_dir)
            .Select(x => x.Item1)
            .Where(x => x != image_dir)
            .OrderBy(x => x)
            .ToArray();

        var result = new Dictionary<string, Dictionary<string, string[]>>();

        foreach (var sub_dir in sub_dirs)
        {
            // Path.GetFileName understands both directory separators.
            var dir_name = Path.GetFileName(sub_dir);
            print($"Looking for images in '{dir_name}'");
            var file_list = Directory.GetFiles(sub_dir);
            if (len(file_list) < 20)
                print($"WARNING: Folder has less than 20 images, which may cause issues.");

            // Labels are identifiers, so lower-case them independently of the current culture.
            var label_name = dir_name.ToLowerInvariant();
            int testing_count = (int)Math.Floor(file_list.Length * testing_percentage);
            int validation_count = (int)Math.Floor(file_list.Length * validation_percentage);

            result[label_name] = new Dictionary<string, string[]>
            {
                ["testing"] = file_list.Take(testing_count).ToArray(),
                ["validation"] = file_list.Skip(testing_count).Take(validation_count).ToArray(),
                ["training"] = file_list.Skip(testing_count + validation_count).ToArray()
            };
        }

        return result;
    }
}
185+
}

test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Collections.Generic;
44
using System.IO;
55
using System.Linq;
6+
using System.Security.Cryptography;
67
using System.Text;
78
using System.Text.RegularExpressions;
89
using TensorFlowNET.Examples.Utility;
@@ -159,5 +160,21 @@ private static (int[][][], int[]) _pad_sequences(int[][][] sequences, int[] pad_
159160

160161
return (sequences, sequence_length);
161162
}
163+
164+
/// <summary>
/// Computes the MD5 digest of <paramref name="input"/> and formats it as hexadecimal text.
/// </summary>
/// <param name="input">
/// Text to hash. It is encoded as ASCII before hashing, so characters outside
/// the ASCII range are not distinguished from one another.
/// </param>
/// <returns>The 16 digest bytes as 32 upper-case hexadecimal characters.</returns>
public static string CalculateMD5Hash(string input)
{
    // MD5 is IDisposable; release it deterministically instead of leaving it to the finalizer.
    using (MD5 md5 = MD5.Create())
    {
        // step 1, calculate MD5 hash from input
        byte[] inputBytes = Encoding.ASCII.GetBytes(input);
        byte[] hash = md5.ComputeHash(inputBytes);

        // step 2, convert byte array to hex string
        var sb = new StringBuilder(hash.Length * 2);
        foreach (byte b in hash)
        {
            sb.Append(b.ToString("X2"));
        }
        return sb.ToString();
    }
}
162179
}
163180
}

0 commit comments

Comments
 (0)