Skip to content

Commit 3d7ff13

Browse files
committed
change constant creation method.
1 parent de082e1 commit 3d7ff13

12 files changed

Lines changed: 240 additions & 28 deletions

File tree

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public static partial class tf
8+
{
9+
public static object get_collection(string key, string scope = "") => get_default_graph()
10+
.get_collection(key, scope: scope);
11+
}
12+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations.Losses
6+
{
7+
class losses_impl
8+
{
9+
}
10+
}

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ public static Tensor shape(Tensor input, string name = "", TF_DataType out_type
8282
return shape_internal(input, name, optimize: true, out_type: out_type);
8383
}
8484

85-
public static Tensor size(Tensor input, string name = "", TF_DataType out_type = TF_DataType.TF_INT32)
85+
public static Tensor size(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
8686
{
87-
return size_internal(input, name, optimize: true, out_type: out_type);
87+
return size_internal(input, name, optimize: optimize, out_type: out_type);
8888
}
8989

9090
private static Tensor shape_internal(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
@@ -132,6 +132,7 @@ private static Tensor size_internal(Tensor input, string name = "", bool optimiz
132132
else
133133
{
134134
// result = gen_array_ops.shape();
135+
throw new NotImplementedException("array_ops.size_internal");
135136
}
136137

137138
return null;

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,28 +46,36 @@ private NDArray _run(object fetches, FeedItem[] feed_dict = null)
4646
var feed_dict_tensor = new Dictionary<object, object>();
4747
var feed_map = new Dictionary<object, object>();
4848

49+
Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) =>
50+
{
51+
return new (object, object)[] { (item.Key, item.Value) };
52+
};
53+
4954
// Validate and process feed_dict.
5055
if (feed_dict != null)
5156
{
52-
foreach(var subfeed in feed_dict)
57+
foreach (var feed in feed_dict)
5358
{
54-
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
55-
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();
56-
switch(subfeed.Value)
59+
foreach (var (subfeed, subfeed_val) in feed_fn(feed))
5760
{
58-
case float floatVal:
59-
feed_dict_tensor[subfeed_t] = (NDArray)floatVal;
60-
break;
61-
case int intVal:
62-
feed_dict_tensor[subfeed_t] = (NDArray)intVal;
63-
break;
64-
case string str:
65-
feed_dict_tensor[subfeed_t] = (NDArray)str;
66-
break;
67-
default:
68-
throw new NotImplementedException("_run subfeed");
61+
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
62+
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();
63+
switch (subfeed_val)
64+
{
65+
case float floatVal:
66+
feed_dict_tensor[subfeed_t] = (NDArray)floatVal;
67+
break;
68+
case int intVal:
69+
feed_dict_tensor[subfeed_t] = (NDArray)intVal;
70+
break;
71+
case string str:
72+
feed_dict_tensor[subfeed_t] = (NDArray)str;
73+
break;
74+
default:
75+
throw new NotImplementedException("_run subfeed");
76+
}
77+
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
6978
}
70-
feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
7179
}
7280
}
7381

src/TensorFlowNET.Core/Tensors/constant_op.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtIn
2424
return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true);
2525
}
2626

27-
private static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast)
27+
public static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast)
2828
{
2929
if (tf.context.executing_eagerly())
3030
{

src/TensorFlowNET.Core/Tensors/tf.constant.cs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,26 @@ namespace Tensorflow
77
{
88
public static partial class tf
99
{
10-
public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name);
10+
// public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name);
11+
12+
public static Tensor constant(object value,
13+
TF_DataType dtype = TF_DataType.DtInvalid,
14+
int[] shape = null,
15+
string name = "Const",
16+
bool verify_shape = false) => constant_op._constant_impl(value,
17+
dtype,
18+
shape,
19+
name,
20+
verify_shape: verify_shape,
21+
allow_broadcast: false);
1122

1223
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") => array_ops.zeros(shape, dtype, name);
24+
25+
public static Tensor size(Tensor input,
26+
string name = "",
27+
TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input,
28+
name,
29+
optimize: true,
30+
out_type: out_type);
1331
}
1432
}

src/TensorFlowNET.Core/Train/Saving/Saver.cs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ public Saver(RefVariable[] var_list = null,
5555
_keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours;
5656
_name = name;
5757
_restore_sequentially = restore_sequentially;
58+
_saver_def = saver_def;
5859
_builder = builder;
5960
_is_built = false;
6061
_allow_empty = allow_empty;
@@ -122,7 +123,7 @@ private void _build(string checkpoint_path, bool build_save, bool build_restore)
122123
}
123124
else if (_saver_def != null && !string.IsNullOrEmpty(_name))
124125
{
125-
throw new NotImplementedException("");
126+
throw new NotImplementedException("Saver._build");
126127
}
127128

128129
_check_saver_def();
@@ -200,6 +201,38 @@ public string save(Session sess,
200201
return saver._import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope);
201202
}
202203

204+
/// <summary>
205+
/// Restores previously saved variables.
206+
///
207+
/// This method runs the ops added by the constructor for restoring variables.
208+
/// It requires a session in which the graph was launched. The variables to
209+
/// restore do not have to have been initialized, as restoring is itself a way
210+
/// to initialize variables.
211+
/// </summary>
212+
/// <param name="sess">A `Session` to use to restore the parameters. None in eager mode.</param>
213+
/// <param name="save_path">Path where parameters were previously saved.</param>
214+
public void restore(Session sess, string save_path)
215+
{
216+
if (_is_empty)
217+
return;
218+
219+
if (string.IsNullOrEmpty(save_path))
220+
throw new ValueError("Can't load save_path when it is None.");
221+
222+
if (!checkpoint_management.checkpoint_exists(save_path))
223+
throw new ValueError($"The passed save_path is not a valid checkpoint: {save_path}");
224+
225+
Console.WriteLine($"Restoring parameters from {save_path}");
226+
227+
if (tf.context.executing_eagerly())
228+
;
229+
else
230+
sess.run(_saver_def.RestoreOpName, new FeedItem[]
231+
{
232+
new FeedItem(_saver_def.FilenameTensorName, save_path)
233+
});
234+
}
235+
203236
/// <summary>
204237
/// Writes `MetaGraphDef` to save_path/filename.
205238
/// </summary>

src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.IO;
44
using System.Linq;
55
using System.Text;
6+
using static Tensorflow.SaverDef.Types;
67

78
namespace Tensorflow
89
{
@@ -105,5 +106,23 @@ public static string meta_graph_filename(string checkpoint_filename, string meta
105106
string suffixed_filename = basename + "." + meta_graph_suffix;
106107
return suffixed_filename;
107108
}
109+
110+
public static bool checkpoint_exists(string checkpoint_prefix)
111+
{
112+
string pathname = _prefix_to_checkpoint_path(checkpoint_prefix, CheckpointFormatVersion.V2);
113+
if (File.Exists(pathname))
114+
return true;
115+
else if (File.Exists(checkpoint_prefix))
116+
return true;
117+
else
118+
return false;
119+
}
120+
121+
private static string _prefix_to_checkpoint_path(string prefix, CheckpointFormatVersion format_version)
122+
{
123+
if (format_version == CheckpointFormatVersion.V2)
124+
return prefix + ".index";
125+
return prefix;
126+
}
108127
}
109128
}

src/TensorFlowNET.Core/Train/Saving/saver.py.cs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45

56
namespace Tensorflow
@@ -13,25 +14,43 @@ public static (Saver, object) _import_meta_graph_with_return_elements(string met
1314
{
1415
var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file);
1516

16-
var imported_vars = meta_graph.import_scoped_meta_graph_with_return_elements(
17+
var meta = meta_graph.import_scoped_meta_graph_with_return_elements(
1718
meta_graph_def,
1819
clear_devices: clear_devices,
1920
import_scope: import_scope,
2021
return_elements: return_elements);
2122

23+
var (imported_vars, imported_return_elements) = meta;
24+
2225
var saver = _create_saver_from_imported_meta_graph(
2326
meta_graph_def, import_scope, imported_vars);
2427

2528
return (saver, null);
2629
}
2730

31+
/// <summary>
32+
/// Return a saver for restoring variable values to an imported MetaGraph.
33+
/// </summary>
34+
/// <param name="meta_graph_def"></param>
35+
/// <param name="import_scope"></param>
36+
/// <param name="imported_vars"></param>
37+
/// <returns></returns>
2838
public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def,
2939
string import_scope,
30-
(Dictionary<string, RefVariable>, ITensorOrOperation[]) imported_vars)
40+
Dictionary<string, RefVariable> imported_vars)
3141
{
3242
if(meta_graph_def.SaverDef != null)
3343
{
34-
throw new NotImplementedException("_create_saver_from_imported_meta_graph");
44+
// Infer the scope that is prepended by `import_scoped_meta_graph`.
45+
string scope = import_scope;
46+
var var_names = imported_vars.Keys.ToArray();
47+
if(var_names.Length > 0)
48+
{
49+
var sample_key = var_names[0];
50+
var sample_var = imported_vars[sample_key];
51+
scope = string.Join("", sample_var.name.Skip(sample_key.Length));
52+
}
53+
return new Saver(saver_def: meta_graph_def.SaverDef, name: scope);
3554
}
3655
else
3756
{
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow;
5+
6+
namespace TensorFlowNET.Examples
7+
{
8+
public class MetaGraph : Python, IExample
9+
{
10+
public void Run()
11+
{
12+
ImportMetaGraph("my-save-dir/");
13+
}
14+
15+
private void ImportMetaGraph(string dir)
16+
{
17+
with<Session>(tf.Session(), sess =>
18+
{
19+
var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta");
20+
new_saver.restore(sess, dir + "my-model-10000");
21+
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
22+
var batch_size = tf.size(labels);
23+
var logits = (tf.get_collection("logits") as List<ITensorOrOperation>)[0];
24+
var loss = tf.losses.sparse_softmax_cross_entropy(labels = labels,
25+
logits = logits);
26+
});
27+
}
28+
}
29+
}

0 commit comments

Comments
 (0)