Skip to content

Commit 94c8c6b

Browse files
committed
Created TensorConverter
1 parent b7ccf3b commit 94c8c6b

1 file changed

Lines changed: 285 additions & 0 deletions

File tree

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
using System;
2+
using System.Threading.Tasks;
3+
using NumSharp;
4+
using NumSharp.Backends;
5+
using NumSharp.Utilities;
6+
7+
namespace Tensorflow
8+
{
9+
/// <summary>
10+
/// Provides various methods to conversion between types and <see cref="Tensor"/>.
11+
/// </summary>
12+
public static class TensorConverter
13+
{
14+
/// <summary>
15+
/// Convert given <see cref="Array"/> to <see cref="Tensor"/>.
16+
/// </summary>
17+
/// <param name="nd">The ndarray to convert, can be regular, jagged or multi-dim array.</param>
18+
/// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
19+
/// <exception cref="NotSupportedException"></exception>
20+
public static Tensor ToTensor(NDArray nd, TF_DataType? astype = null)
21+
{
22+
return new Tensor(astype == null ? nd : nd.astype(astype.Value.as_numpy_typecode(), false));
23+
}
24+
25+
/// <summary>
26+
/// Convert given <see cref="NDArray"/> to <see cref="Tensor"/>.
27+
/// </summary>
28+
/// <param name="nd">The ndarray to convert.</param>
29+
/// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
30+
/// <exception cref="NotSupportedException"></exception>
31+
public static Tensor ToTensor(NDArray nd, NPTypeCode? astype = null)
32+
{
33+
return new Tensor(astype == null ? nd : nd.astype(astype.Value, false));
34+
}
35+
36+
/// <summary>
37+
/// Convert given <see cref="Array"/> to <see cref="Tensor"/>.
38+
/// </summary>
39+
/// <param name="array">The array to convert, can be regular, jagged or multi-dim array.</param>
40+
/// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
41+
/// <exception cref="NotSupportedException"></exception>
42+
public static Tensor ToTensor(Array array, TF_DataType? astype = null)
43+
{
44+
if (array == null) throw new ArgumentNullException(nameof(array));
45+
var arrtype = array.ResolveElementType();
46+
47+
var astype_type = astype?.as_numpy_dtype() ?? arrtype;
48+
if (astype_type == arrtype)
49+
{
50+
//no conversion required
51+
if (astype == TF_DataType.TF_STRING)
52+
{
53+
throw new NotSupportedException(); //TODO! when string is fully implemented.
54+
}
55+
56+
if (astype == TF_DataType.TF_INT8)
57+
{
58+
if (array.Rank != 1 || array.GetType().GetElementType()?.IsArray == true) //is multidim or jagged
59+
array = Arrays.Flatten(array);
60+
61+
return new Tensor((sbyte[]) array);
62+
}
63+
64+
//is multidim or jagged, if so - use NDArrays constructor as it records shape.
65+
if (array.Rank != 1 || array.GetType().GetElementType().IsArray)
66+
return new Tensor(new NDArray(array));
67+
68+
#if _REGEN
69+
#region Compute
70+
switch (arrtype)
71+
{
72+
%foreach supported_dtypes,supported_dtypes_lowercase%
73+
case NPTypeCode.#1: return new Tensor((#2[])arr);
74+
%
75+
default:
76+
throw new NotSupportedException();
77+
}
78+
#endregion
79+
#else
80+
81+
#region Compute
82+
83+
switch (arrtype.GetTypeCode())
84+
{
85+
case NPTypeCode.Boolean: return new Tensor((bool[]) array);
86+
case NPTypeCode.Byte: return new Tensor((byte[]) array);
87+
case NPTypeCode.Int16: return new Tensor((short[]) array);
88+
case NPTypeCode.UInt16: return new Tensor((ushort[]) array);
89+
case NPTypeCode.Int32: return new Tensor((int[]) array);
90+
case NPTypeCode.UInt32: return new Tensor((uint[]) array);
91+
case NPTypeCode.Int64: return new Tensor((long[]) array);
92+
case NPTypeCode.UInt64: return new Tensor((ulong[]) array);
93+
case NPTypeCode.Char: return new Tensor((char[]) array);
94+
case NPTypeCode.Double: return new Tensor((double[]) array);
95+
case NPTypeCode.Single: return new Tensor((float[]) array);
96+
default:
97+
throw new NotSupportedException();
98+
}
99+
100+
#endregion
101+
102+
#endif
103+
} else
104+
{
105+
//conversion is required.
106+
//by this point astype is not null.
107+
108+
//flatten if required
109+
if (array.Rank != 1 || array.GetType().GetElementType()?.IsArray == true) //is multidim or jagged
110+
array = Arrays.Flatten(array);
111+
112+
try
113+
{
114+
return ToTensor(
115+
ArrayConvert.To(array, astype.Value.as_numpy_typecode()),
116+
null
117+
);
118+
} catch (NotSupportedException)
119+
{
120+
//handle dtypes not supported by ArrayConvert
121+
var ret = Array.CreateInstance(astype_type, array.LongLength);
122+
Parallel.For(0, ret.LongLength, i => ret.SetValue(Convert.ChangeType(array.GetValue(i), astype_type), i));
123+
return ToTensor(ret, null);
124+
}
125+
}
126+
}
127+
128+
/// <summary>
129+
/// Convert given <see cref="Array"/> to <see cref="Tensor"/>.
130+
/// </summary>
131+
/// <param name="constant">The constant scalar to convert</param>
132+
/// <param name="astype">Convert <paramref name="constant"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
133+
/// <exception cref="NotSupportedException"></exception>
134+
public static Tensor ToTensor<T>(T constant, TF_DataType? astype = null) where T : unmanaged
135+
{
136+
//was conversion requested?
137+
if (astype == null)
138+
{
139+
//No conversion required
140+
var constantType = typeof(T).as_dtype();
141+
if (constantType == TF_DataType.TF_INT8)
142+
return new Tensor((sbyte) (object) constant);
143+
144+
if (constantType == TF_DataType.TF_STRING)
145+
return new Tensor((string) (object) constant);
146+
147+
#if _REGEN
148+
#region Compute
149+
switch (InfoOf<T>.NPTypeCode)
150+
{
151+
%foreach supported_dtypes,supported_dtypes_lowercase%
152+
case NPTypeCode.#1: return new Tensor((#2)(object)constant);
153+
%
154+
default:
155+
throw new NotSupportedException();
156+
}
157+
#endregion
158+
#else
159+
160+
#region Compute
161+
162+
switch (InfoOf<T>.NPTypeCode)
163+
{
164+
case NPTypeCode.Boolean: return new Tensor((bool) (object) constant);
165+
case NPTypeCode.Byte: return new Tensor((byte) (object) constant);
166+
case NPTypeCode.Int16: return new Tensor((short) (object) constant);
167+
case NPTypeCode.UInt16: return new Tensor((ushort) (object) constant);
168+
case NPTypeCode.Int32: return new Tensor((int) (object) constant);
169+
case NPTypeCode.UInt32: return new Tensor((uint) (object) constant);
170+
case NPTypeCode.Int64: return new Tensor((long) (object) constant);
171+
case NPTypeCode.UInt64: return new Tensor((ulong) (object) constant);
172+
case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant));
173+
case NPTypeCode.Double: return new Tensor((double) (object) constant);
174+
case NPTypeCode.Single: return new Tensor((float) (object) constant);
175+
default:
176+
throw new NotSupportedException();
177+
}
178+
179+
#endregion
180+
#endif
181+
}
182+
183+
//conversion required
184+
185+
if (astype == TF_DataType.TF_INT8)
186+
return new Tensor(Converts.ToSByte(constant));
187+
188+
if (astype == TF_DataType.TF_STRING)
189+
return new Tensor(Converts.ToString(constant));
190+
191+
var astype_np = astype?.as_numpy_typecode();
192+
193+
#if _REGEN
194+
#region Compute
195+
switch (astype_np)
196+
{
197+
%foreach supported_dtypes,supported_dtypes_lowercase%
198+
case NPTypeCode.#1: return new Tensor(Converts.To#1(constant));
199+
%
200+
default:
201+
throw new NotSupportedException();
202+
}
203+
#endregion
204+
#else
205+
206+
#region Compute
207+
switch (astype_np)
208+
{
209+
case NPTypeCode.Boolean: return new Tensor(Converts.ToBoolean(constant));
210+
case NPTypeCode.Byte: return new Tensor(Converts.ToByte(constant));
211+
case NPTypeCode.Int16: return new Tensor(Converts.ToInt16(constant));
212+
case NPTypeCode.UInt16: return new Tensor(Converts.ToUInt16(constant));
213+
case NPTypeCode.Int32: return new Tensor(Converts.ToInt32(constant));
214+
case NPTypeCode.UInt32: return new Tensor(Converts.ToUInt32(constant));
215+
case NPTypeCode.Int64: return new Tensor(Converts.ToInt64(constant));
216+
case NPTypeCode.UInt64: return new Tensor(Converts.ToUInt64(constant));
217+
case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant));
218+
case NPTypeCode.Double: return new Tensor(Converts.ToDouble(constant));
219+
case NPTypeCode.Single: return new Tensor(Converts.ToSingle(constant));
220+
default:
221+
throw new NotSupportedException();
222+
}
223+
#endregion
224+
#endif
225+
}
226+
227+
/// <summary>
228+
/// Convert given <see cref="Array"/> to <see cref="Tensor"/>.
229+
/// </summary>
230+
/// <param name="constant">The constant scalar to convert</param>
231+
/// <param name="astype">Convert <paramref name="constant"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
232+
/// <exception cref="NotSupportedException"></exception>
233+
public static Tensor ToTensor(string constant, TF_DataType? astype = null)
234+
{
235+
switch (astype)
236+
{
237+
//was conversion requested?
238+
case null:
239+
case TF_DataType.TF_STRING:
240+
return new Tensor(constant);
241+
//conversion required
242+
case TF_DataType.TF_INT8:
243+
return new Tensor(Converts.ToSByte(constant));
244+
default:
245+
{
246+
var astype_np = astype?.as_numpy_typecode();
247+
248+
#if _REGEN
249+
#region Compute
250+
switch (astype_np)
251+
{
252+
%foreach supported_dtypes,supported_dtypes_lowercase%
253+
case NPTypeCode.#1: return new Tensor(Converts.To#1(constant));
254+
%
255+
default:
256+
throw new NotSupportedException();
257+
}
258+
#endregion
259+
#else
260+
261+
#region Compute
262+
switch (astype_np)
263+
{
264+
case NPTypeCode.Boolean: return new Tensor(Converts.ToBoolean(constant));
265+
case NPTypeCode.Byte: return new Tensor(Converts.ToByte(constant));
266+
case NPTypeCode.Int16: return new Tensor(Converts.ToInt16(constant));
267+
case NPTypeCode.UInt16: return new Tensor(Converts.ToUInt16(constant));
268+
case NPTypeCode.Int32: return new Tensor(Converts.ToInt32(constant));
269+
case NPTypeCode.UInt32: return new Tensor(Converts.ToUInt32(constant));
270+
case NPTypeCode.Int64: return new Tensor(Converts.ToInt64(constant));
271+
case NPTypeCode.UInt64: return new Tensor(Converts.ToUInt64(constant));
272+
case NPTypeCode.Char: return new Tensor(Converts.ToByte(constant));
273+
case NPTypeCode.Double: return new Tensor(Converts.ToDouble(constant));
274+
case NPTypeCode.Single: return new Tensor(Converts.ToSingle(constant));
275+
default:
276+
throw new NotSupportedException();
277+
}
278+
#endregion
279+
#endif
280+
}
281+
}
282+
}
283+
284+
}
285+
}

0 commit comments

Comments
 (0)