forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmap_fn.cs
More file actions
86 lines (74 loc) · 3.13 KB
/
map_fn.cs
File metadata and controls
86 lines (74 loc) · 3.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations;
using static Tensorflow.Binding;
namespace Tensorflow
{
public partial class Operation
{
    /// <summary>
    /// Map on the list of tensors unpacked from `elems` on dimension 0.
    /// </summary>
    /// <remarks>
    /// This is an unfinished port. The method currently unstacks `elems` into an input
    /// TensorArray and allocates the output accumulator TensorArray, but the loop that
    /// would apply <paramref name="fn"/> to each element is not written yet, so every
    /// call that passes argument validation ends in <see cref="NotImplementedException"/>.
    /// </remarks>
    /// <param name="fn">Function to apply to each unpacked element. Validated for null but not invoked yet.</param>
    /// <param name="elems">Tensor to unpack along dimension 0; its static first dimension sizes the TensorArrays.</param>
    /// <param name="dtype">
    /// Element dtype for the accumulator TensorArray. When left at
    /// <c>TF_DataType.DtInvalid</c> the dtype of <paramref name="elems"/> is used.
    /// </param>
    /// <param name="parallel_iterations">Reserved for the pending while loop; currently unused.</param>
    /// <param name="back_prop">Reserved for the pending while loop; currently unused.</param>
    /// <param name="swap_memory">Reserved for the pending while loop; currently unused.</param>
    /// <param name="infer_shape">Forwarded as <c>infer_shape</c> to the accumulator TensorArray.</param>
    /// <param name="name">Name handed to <c>ops.name_scope</c>; the default scope name is "map".</param>
    /// <returns>Nothing at present: the method always throws before returning.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="fn"/> or <paramref name="elems"/> is null.</exception>
    /// <exception cref="NotImplementedException">Always thrown once the setup code has run.</exception>
    public static Tensor map_fn(Func<Tensor, Tensor> fn,
        Tensor elems,
        TF_DataType dtype = TF_DataType.DtInvalid,
        int parallel_iterations = 10,
        bool back_prop = true,
        bool swap_memory = false,
        bool infer_shape = true,
        string name = null)
    {
        // Fail early with a precise exception instead of an obscure failure during graph construction.
        if (fn is null)
        {
            throw new ArgumentNullException(nameof(fn));
        }

        if (elems is null)
        {
            throw new ArgumentNullException(nameof(elems));
        }

        var elems_flat = new[] { elems };
        tf_with(ops.name_scope(name, "map", elems_flat), delegate
        {
            // Not consumed yet; kept for the pending loop implementation.
            var varscope = tf.get_variable_scope();

            elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem"))
                .ToArray();

            // Only fall back to the input dtype when the caller did not request one.
            // Previously an explicitly supplied dtype was overwritten unconditionally.
            if (dtype == TF_DataType.DtInvalid)
            {
                dtype = elems_flat[0].dtype;
            }

            var dtype_flat = new[] { dtype };

            // Convert elems to tensor array. n is taken from the static shape only.
            var static_shape = elems_flat[0].shape;
            var n = static_shape[0];

            // TensorArrays are always flat
            var elems_ta = elems_flat.Select(elem => new TensorArray(dtype: elem.dtype,
                size: ops.convert_to_tensor(n),
                dynamic_size: false,
                infer_shape: true)).ToArray();

            // Unpack elements
            var elems_ta_1 = new List<TensorArray>();
            foreach (var (elem_ta, elem) in zip(elems_ta, elems_flat))
            {
                elems_ta_1.Add(elem_ta.unstack(elem));
            }

            elems_ta = elems_ta_1.ToArray();

            // Loop counter for the pending while loop (referenced by the sketch below).
            var i = constant_op.constant(0);

            // One accumulator per output dtype; this is where the results of fn are meant to be written.
            var accs_ta = dtype_flat.Select(dt => new TensorArray(dtype: dt,
                size: ops.convert_to_tensor(n),
                dynamic_size: false,
                infer_shape: infer_shape)).ToArray();

            // TODO: implement the loop body and wire it into control_flow_ops.while_loop.
            // Unfinished sketch kept for reference:
            /*Func<Tensor, TensorArray> compute = (i, tas) =>
            {
                throw new NotImplementedException("");
            };
            var r_a = control_flow_ops.while_loop(
                (i, _) => i < n,
                compute,
                new[] { i, accs_ta },
                parallel_iterations: parallel_iterations,
                back_prop: back_prop,
                swap_memory: swap_memory,
                maximum_iterations: n);*/
        });

        throw new NotImplementedException(
            "map_fn is not implemented yet: the while loop that applies fn to each element is missing.");
    }
}
}