Skip to content

Commit 060cc37

Browse files
committed
tf.sparse_tensor_to_dense, TensorShape.merge_with SciSharp#396
1 parent acb5505 commit 060cc37

7 files changed

Lines changed: 157 additions & 131 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.sparse.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,20 @@ namespace Tensorflow
2020
{
2121
public partial class tensorflow
2222
{
23-
public SparseTensor<T> SparseTensor<T>(long[,] indices, T[] values, int[] dense_shape)
23+
public SparseTensor<T> SparseTensor<T>(long[,] indices, T[] values, long[] dense_shape)
2424
=> new SparseTensor<T>(indices, values, dense_shape);
2525

26+
public Tensor sparse_tensor_to_dense<T>(SparseTensor<T> sp_input,
27+
T default_value = default,
28+
bool validate_indices = true,
29+
string name = null)
30+
=> gen_sparse_ops.sparse_to_dense(sp_input.indices,
31+
sp_input.dense_shape,
32+
sp_input.values,
33+
default_value: default_value,
34+
validate_indices: validate_indices,
35+
name: name);
36+
2637
/// <summary>
2738
/// Converts a sparse representation into a dense tensor.
2839
/// </summary>

src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using static Tensorflow.Binding;
1+
using System;
2+
using System.Linq;
3+
using static Tensorflow.Binding;
24

35
namespace Tensorflow.Framework
46
{
@@ -8,15 +10,20 @@ namespace Tensorflow.Framework
810
public class SparseTensor<T> : CompositeTensor, _TensorLike
911
{
1012
long[,] _indices;
11-
Tensor indices;
13+
public Tensor indices;
1214

1315
T[] _values;
14-
Tensor values;
16+
public Tensor values;
1517

16-
int[] _dense_shape;
17-
Tensor dense_shape;
18+
long[] _dense_shape;
19+
public Tensor dense_shape;
1820

19-
public SparseTensor(long[,] indices_, T[] values_, int[] dense_shape_)
21+
TensorShape _shape;
22+
public TensorShape shape => _shape;
23+
24+
public TF_DataType dtype => dtypes.as_dtype(typeof(T));
25+
26+
public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_)
2027
{
2128
tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate
2229
{
@@ -37,6 +44,8 @@ public SparseTensor(long[,] indices_, T[] values_, int[] dense_shape_)
3744

3845
indices_shape[0].merge_with(values_shape.dims[0]);
3946
indices_shape[1].merge_with(dense_shape_shape.dims[0]);
47+
48+
_shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray());
4049
}
4150
}
4251

src/TensorFlowNET.Core/Operations/gen_sparse_ops.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515
******************************************************************************/
1616

1717
using System.Collections.Generic;
18+
using Tensorflow.Framework;
1819

1920
namespace Tensorflow
2021
{
@@ -50,5 +51,24 @@ public static Tensor sparse_to_dense<T>(Tensor sparse_indices,
5051

5152
return _op.output;
5253
}
54+
55+
public static Tensor sparse_to_dense<T>(Tensor sparse_indices,
56+
Tensor output_shape,
57+
Tensor sparse_values,
58+
T default_value = default,
59+
bool validate_indices = true,
60+
string name = null)
61+
{
62+
var _op = _op_def_lib._apply_op_helper("SparseToDense", name, args: new
63+
{
64+
sparse_indices,
65+
output_shape,
66+
sparse_values,
67+
default_value,
68+
validate_indices
69+
});
70+
71+
return _op.output;
72+
}
5373
}
5474
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class Dimension
8+
{
9+
int _value;
10+
public int value => _value;
11+
12+
public Dimension(int value)
13+
{
14+
_value = value;
15+
}
16+
17+
public Dimension merge_with(Dimension other)
18+
{
19+
if (_value == -1)
20+
return new Dimension(other.value);
21+
else
22+
return new Dimension(_value);
23+
}
24+
25+
public override string ToString() => $"Dimension({_value})";
26+
}
27+
}

src/TensorFlowNET.Core/Tensors/TensorShape.cs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
using NumSharp;
22
using System;
3+
using System.Collections.Generic;
34
using System.Diagnostics.CodeAnalysis;
45
using System.Linq;
56
using System.Runtime.CompilerServices;
6-
using NumSharp.Utilities;
7+
using static Tensorflow.Binding;
78

89
namespace Tensorflow
910
{
@@ -196,12 +197,26 @@ public TensorShape concatenate(TensorShape other)
196197
}
197198
}
198199

200+
/// <summary>
201+
/// Returns a `TensorShape` combining the information in `self` and `other`.
202+
/// </summary>
203+
/// <param name="other"></param>
204+
/// <returns></returns>
199205
public TensorShape merge_with(TensorShape other)
200206
{
201207
if (dims.Length == 0)
202208
return other;
203209

204-
throw new NotImplementedException("merge_with");
210+
var new_dims = new List<int>();
211+
212+
foreach (var i in range(ndim))
213+
{
214+
var dim = new Dimension(dims[i]);
215+
var merged = dim.merge_with(new Dimension(other.dims[i]));
216+
new_dims.Add(merged.value);
217+
}
218+
219+
return new TensorShape(new_dims.ToArray());
205220
}
206221

207222
/// <summary>

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 51 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -118,110 +118,10 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
118118
if (values == null)
119119
throw new ValueError("None values not supported.");
120120

121-
if(np_dt == null)
122-
{
123-
switch (values)
124-
{
125-
case bool boolVal:
126-
nparray = boolVal;
127-
break;
128-
case int intVal:
129-
nparray = intVal;
130-
break;
131-
case int[] intVals:
132-
nparray = np.array(intVals);
133-
break;
134-
case int[,] intVals:
135-
nparray = np.array(intVals);
136-
break;
137-
case long intVal:
138-
nparray = intVal;
139-
break;
140-
case long[] intVals:
141-
nparray = np.array(intVals);
142-
break;
143-
case long[,] intVals:
144-
nparray = np.array(intVals);
145-
break;
146-
case float floatVal:
147-
nparray = floatVal;
148-
break;
149-
case float[] floatVals:
150-
nparray = floatVals;
151-
break;
152-
case float[,] floatVals:
153-
nparray = np.array(floatVals);
154-
break;
155-
case double doubleVal:
156-
nparray = doubleVal;
157-
break;
158-
case double[] doubleVals:
159-
nparray = np.array(doubleVals);
160-
break;
161-
case double[,] doubleVals:
162-
nparray = np.array(doubleVals);
163-
break;
164-
case string strVal:
165-
nparray = strVal;
166-
break;
167-
case string[] strVals:
168-
nparray = strVals;
169-
break;
170-
case byte[] byteValues:
171-
nparray = byteValues;
172-
break;
173-
case byte[,] byteValues:
174-
nparray = np.array(byteValues);
175-
break;
176-
default:
177-
throw new NotImplementedException($"make_tensor_proto: Support for type {values.GetType()} Not Implemented");
178-
}
179-
}
180-
else
181-
{
182-
// convert data type
183-
switch (np_dt.Name)
184-
{
185-
case "Int32":
186-
if (values.GetType().IsArray)
187-
nparray = np.array((int[])values, np_dt);
188-
else
189-
nparray = Converts.ToInt32(values);
190-
break;
191-
case "Int64":
192-
if (values.GetType().IsArray)
193-
nparray = np.array((int[])values, np_dt);
194-
else
195-
nparray = Converts.ToInt64(values);
196-
break;
197-
case "Single":
198-
if (values.GetType().IsArray)
199-
nparray = np.array((float[])values, np_dt);
200-
else
201-
nparray = Converts.ToSingle(values);
202-
break;
203-
case "Double":
204-
if (values.GetType().IsArray)
205-
nparray = np.array((double[])values, np_dt);
206-
else
207-
nparray = Converts.ToDouble(values);
208-
break;
209-
case "String":
210-
if (values.GetType().IsArray)
211-
nparray = np.array((string[])values, np_dt);
212-
else
213-
nparray = NDArray.FromString(Converts.ToString(values));
214-
break;
215-
case "Boolean":
216-
if (values.GetType().IsArray)
217-
nparray = np.array((bool[])values, np_dt);
218-
else
219-
nparray = Converts.ToBoolean(values);
220-
break;
221-
default:
222-
throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented");
223-
}
224-
}
121+
nparray = convert_to_numpy_ndarray(values);
122+
123+
if (np_dt != null && np_dt != typeof(string))
124+
nparray = nparray.astype(np_dt);
225125
}
226126

227127
var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype);
@@ -316,23 +216,59 @@ public static NDArray convert_to_numpy_ndarray(object values)
316216
case NDArray val:
317217
nd = val;
318218
break;
319-
case int val:
320-
nd = np.asarray(val);
219+
case bool boolVal:
220+
nd = boolVal;
221+
break;
222+
case int intVal:
223+
nd = intVal;
224+
break;
225+
case int[] intVals:
226+
nd = np.array(intVals);
227+
break;
228+
case int[,] intVals:
229+
nd = np.array(intVals);
230+
break;
231+
case long intVal:
232+
nd = intVal;
233+
break;
234+
case long[] intVals:
235+
nd = np.array(intVals);
236+
break;
237+
case long[,] intVals:
238+
nd = np.array(intVals);
239+
break;
240+
case float floatVal:
241+
nd = floatVal;
242+
break;
243+
case float[] floatVals:
244+
nd = floatVals;
245+
break;
246+
case float[,] floatVals:
247+
nd = np.array(floatVals);
248+
break;
249+
case double doubleVal:
250+
nd = doubleVal;
251+
break;
252+
case double[] doubleVals:
253+
nd = np.array(doubleVals);
254+
break;
255+
case double[,] doubleVals:
256+
nd = np.array(doubleVals);
321257
break;
322-
case int[] val:
323-
nd = np.array(val);
258+
case string strVal:
259+
nd = NDArray.FromString(strVal);
324260
break;
325-
case float val:
326-
nd = np.asarray(val);
261+
case string[] strVals:
262+
nd = strVals;
327263
break;
328-
case double val:
329-
nd = np.asarray(val);
264+
case byte[] byteValues:
265+
nd = byteValues;
330266
break;
331-
case string val:
332-
nd = np.asarray(val);
267+
case byte[,] byteValues:
268+
nd = np.array(byteValues);
333269
break;
334270
default:
335-
throw new Exception("Not Implemented");
271+
throw new NotImplementedException($"convert_to_numpy_ndarray: Support for type {values.GetType()} Not Implemented");
336272
}
337273

338274
return nd;

test/TensorFlowNET.UnitTest/TensorTest.cs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -225,14 +225,22 @@ public void sparse_to_dense()
225225
[TestMethod]
226226
public void sparse_tensor_to_dense()
227227
{
228-
/*int[,] dense_array =
228+
var decoded_list = tf.SparseTensor(new[,]
229229
{
230-
{ 1, 0, 0, 0, 0 },
231-
{ 0, 1, 0, 0, 0 },
232-
{ 0, 0, 1, 0, 0 },
233-
{ 0, 0, 0, 1, 0 }
234-
};
235-
var sparseTensor = new SparseTensor<int>(indices, values, dense_shape);*/
230+
{ 0L, 0L },
231+
{ 1L, 2L }
232+
},
233+
new int[] { 1, 2 },
234+
new[] { 3L, 4L });
235+
236+
var onehot = tf.sparse_tensor_to_dense(decoded_list);
237+
using (var sess = tf.Session())
238+
{
239+
var result = sess.run(onehot);
240+
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>()));
241+
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>()));
242+
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>()));
243+
}
236244
}
237245
}
238246
}

0 commit comments

Comments
 (0)