Skip to content

Commit 2192f4d

Browse files
committed
fix placeholder feeds value issue.
1 parent 4c51d1b commit 2192f4d

12 files changed

Lines changed: 132 additions & 507 deletions

File tree

src/TensorFlowNET.Console/MemoryMonitor.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ public void WarmUp()
1515
while (true)
1616
{
1717
var ones = np.ones((128, 128));
18+
Thread.Sleep(1);
1819
}
1920

2021
TensorShape shape = (1, 32, 32, 3);

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,5 +505,40 @@ public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey k
505505

506506
return defaultValue;
507507
}
508+
509+
public static Shape GetShape(this object data)
510+
{
511+
if (!data.GetType().IsArray)
512+
return Shape.Scalar;
513+
514+
switch (data)
515+
{
516+
case Array array:
517+
var dims = range(array.Rank).Select(x => (long)array.GetLength(x)).ToArray();
518+
return new Shape(dims);
519+
default:
520+
throw new NotImplementedException("");
521+
}
522+
}
523+
524+
public static unsafe byte[] ToByteArray(Array array)
525+
{
526+
/*var size = array.GetShape().size;
527+
byte[]? bytes = null;
528+
switch (array)
529+
{
530+
case float[] arr:
531+
var len = new byte[size * sizeof(float)];
532+
fixed (void* addr = &arr[0])
533+
System.Buffer.MemoryCopy(addr, dst, bytesize, bytesize);
534+
tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(array.ToArray());
535+
break;
536+
default:
537+
throw new NotImplementedException("");
538+
}
539+
540+
return bytes;*/
541+
throw new NotImplementedException("");
542+
}
508543
}
509544
}

src/TensorFlowNET.Core/Numpy/NDArray.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ public partial class NDArray
1414
public ulong dtypesize => _tensor.itemsize;
1515
public int ndim => _tensor.NDims;
1616
public long[] dims => _tensor.dims.Select(x => Convert.ToInt64(x)).ToArray();
17-
public Shape shape => _tensor.shape;
17+
public Shape shape => _tensor.shape;
18+
public IntPtr data => _tensor.TensorDataPointer;
1819

1920
public NDArray(bool value)
2021
{

src/TensorFlowNET.Core/Operations/image_ops_impl.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -923,7 +923,7 @@ object max_(object x, object y)
923923
int p_height = (int)max_(0, math_ops.cast(f_padding_height, dtype: dtypes.int32));
924924
int p_width = (int)max_(0, math_ops.cast(f_padding_width, dtype: dtypes.int32));
925925

926-
var resized = resize_fn(image, new Tensor(new[] { resized_height, resized_width }));
926+
var resized = resize_fn(image, array_ops.concat(new[] { resized_height, resized_width }, 0));
927927

928928
var padded = pad_to_bounding_box(resized, p_height, p_width, target_height,
929929
target_width);

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,35 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,
176176
var tensor = new Tensor(v);
177177
if (tensor.dtype != key.dtype)
178178
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}");
179-
180179
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor);
181180
break;
182-
default:
183-
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), constant_op.constant(x.Value));
181+
case bool v:
182+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
183+
break;
184+
case byte v:
185+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
186+
break;
187+
case int v:
188+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
189+
break;
190+
case long v:
191+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
184192
break;
193+
case float v:
194+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
195+
break;
196+
case double v:
197+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
198+
break;
199+
case Array v:
200+
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v, v.GetShape()));
201+
break;
202+
default:
203+
throw new NotImplementedException("");
185204
}
186205
}
206+
else
207+
throw new NotImplementedException("");
187208
}
188209

189210
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();

0 commit comments

Comments
 (0)