Skip to content

Commit b252a07

Browse files
committed
BaseSession: Added transparent dtype conversion to feed_dict
1 parent 94c8c6b commit b252a07

1 file changed

Lines changed: 43 additions & 36 deletions

File tree

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 43 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
107107
foreach (var subfeed in feed_dict)
108108
{
109109
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
110-
//var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used
110+
//var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used
111111
feed_dict_tensor[subfeed_t] = subfeed.Value;
112112
//feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
113113
}
@@ -150,58 +150,64 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,
150150
int i = 0;
151151
foreach (var x in feed_dict)
152152
{
153-
if (x.Key is Tensor tensor)
153+
if (x.Key is Tensor key)
154154
{
155155
switch (x.Value)
156156
{
157157
case Tensor v:
158-
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
158+
if (v.dtype != key.dtype)
159+
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}");
160+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v);
159161
break;
160162
case NDArray v:
161-
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
163+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype));
162164
break;
163165
case IntPtr v:
164-
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
166+
var tensor = new Tensor(v);
167+
if (tensor.dtype != key.dtype)
168+
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}");
169+
170+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor);
165171
break;
166172
#if _REGEN
167173
// @formatter:off — disable formatter after this line
168-
%types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
169-
%foreach types%
170-
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
171-
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
172-
%
174+
%types = ["bool", "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
175+
%foreach types%
176+
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
177+
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
178+
%
173179
// @formatter:on — enable formatter after this line
174180
#else
175181
// @formatter:off — disable formatter after this line
176-
case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
177-
case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
178-
case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
179-
case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
180-
case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
181-
case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
182-
case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
183-
case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
184-
case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
185-
case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
186-
case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
187-
case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
188-
case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
189-
case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
190-
case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
191-
case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
192-
case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
193-
case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
194-
case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
195-
case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
196-
case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
197-
case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
182+
case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
183+
case bool[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
184+
case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
185+
case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
186+
case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
187+
case byte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
188+
case short v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
189+
case short[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
190+
case ushort v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
191+
case ushort[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
192+
case int v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
193+
case int[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
194+
case uint v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
195+
case uint[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
196+
case long v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
197+
case long[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
198+
case ulong v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
199+
case ulong[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
200+
case float v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
201+
case float[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
202+
case double v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
203+
case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
204+
case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
205+
case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype)); break;
198206
// @formatter:on — enable formatter after this line
199207
#endif
200-
case bool v:
201-
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL));
202-
break;
208+
203209
case string v:
204-
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
210+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), TensorConverter.ToTensor(v, key.dtype));
205211
break;
206212
default:
207213
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
@@ -214,6 +220,7 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,
214220
return _call_tf_sessionrun(feeds, fetches, target_list);
215221
}
216222

223+
217224
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
218225
{
219226
// Ensure any changes to the graph are reflected in the runtime.

0 commit comments

Comments
 (0)