1+ using System ;
2+ using System . Threading . Tasks ;
3+ using NumSharp ;
4+ using NumSharp . Backends ;
5+ using NumSharp . Utilities ;
6+
7+ namespace Tensorflow
8+ {
9+ /// <summary>
10+ /// Provides various methods to conversion between types and <see cref="Tensor"/>.
11+ /// </summary>
12+ public static class TensorConverter
13+ {
14+ /// <summary>
15+ /// Convert given <see cref="Array"/> to <see cref="Tensor"/>.
16+ /// </summary>
17+ /// <param name="nd">The ndarray to convert, can be regular, jagged or multi-dim array.</param>
18+ /// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
19+ /// <exception cref="NotSupportedException"></exception>
20+ public static Tensor ToTensor ( NDArray nd , TF_DataType ? astype = null )
21+ {
22+ return new Tensor ( astype == null ? nd : nd . astype ( astype . Value . as_numpy_typecode ( ) , false ) ) ;
23+ }
24+
25+ /// <summary>
26+ /// Convert given <see cref="NDArray"/> to <see cref="Tensor"/>.
27+ /// </summary>
28+ /// <param name="nd">The ndarray to convert.</param>
29+ /// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
30+ /// <exception cref="NotSupportedException"></exception>
31+ public static Tensor ToTensor ( NDArray nd , NPTypeCode ? astype = null )
32+ {
33+ return new Tensor ( astype == null ? nd : nd . astype ( astype . Value , false ) ) ;
34+ }
35+
36+ /// <summary>
37+ /// Convert given <see cref="Array"/> to <see cref="Tensor"/>.
38+ /// </summary>
39+ /// <param name="array">The array to convert, can be regular, jagged or multi-dim array.</param>
40+ /// <param name="astype">Convert <see cref="Array"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
41+ /// <exception cref="NotSupportedException"></exception>
42+ public static Tensor ToTensor ( Array array , TF_DataType ? astype = null )
43+ {
44+ if ( array == null ) throw new ArgumentNullException ( nameof ( array ) ) ;
45+ var arrtype = array . ResolveElementType ( ) ;
46+
47+ var astype_type = astype ? . as_numpy_dtype ( ) ?? arrtype ;
48+ if ( astype_type == arrtype )
49+ {
50+ //no conversion required
51+ if ( astype == TF_DataType . TF_STRING )
52+ {
53+ throw new NotSupportedException ( ) ; //TODO! when string is fully implemented.
54+ }
55+
56+ if ( astype == TF_DataType . TF_INT8 )
57+ {
58+ if ( array . Rank != 1 || array . GetType ( ) . GetElementType ( ) ? . IsArray == true ) //is multidim or jagged
59+ array = Arrays . Flatten ( array ) ;
60+
61+ return new Tensor ( ( sbyte [ ] ) array ) ;
62+ }
63+
64+ //is multidim or jagged, if so - use NDArrays constructor as it records shape.
65+ if ( array . Rank != 1 || array . GetType ( ) . GetElementType ( ) . IsArray )
66+ return new Tensor ( new NDArray ( array ) ) ;
67+
68+ #if _REGEN
69+ #region Compute
70+ switch ( arrtype )
71+ {
72+ % foreach supported_dtypes, supported_dtypes_lowercase%
73+ case NPTypeCode . #1 : return new Tensor ( ( #2 [ ] ) arr) ;
74+ %
75+ default :
76+ throw new NotSupportedException ( ) ;
77+ }
78+ #endregion
79+ #else
80+
81+ #region Compute
82+
83+ switch ( arrtype . GetTypeCode ( ) )
84+ {
85+ case NPTypeCode . Boolean : return new Tensor ( ( bool [ ] ) array ) ;
86+ case NPTypeCode . Byte : return new Tensor ( ( byte [ ] ) array ) ;
87+ case NPTypeCode . Int16 : return new Tensor ( ( short [ ] ) array ) ;
88+ case NPTypeCode . UInt16 : return new Tensor ( ( ushort [ ] ) array ) ;
89+ case NPTypeCode . Int32 : return new Tensor ( ( int [ ] ) array ) ;
90+ case NPTypeCode . UInt32 : return new Tensor ( ( uint [ ] ) array ) ;
91+ case NPTypeCode . Int64 : return new Tensor ( ( long [ ] ) array ) ;
92+ case NPTypeCode . UInt64 : return new Tensor ( ( ulong [ ] ) array ) ;
93+ case NPTypeCode . Char : return new Tensor ( ( char [ ] ) array ) ;
94+ case NPTypeCode . Double : return new Tensor ( ( double [ ] ) array ) ;
95+ case NPTypeCode . Single : return new Tensor ( ( float [ ] ) array ) ;
96+ default :
97+ throw new NotSupportedException ( ) ;
98+ }
99+
100+ #endregion
101+
102+ #endif
103+ } else
104+ {
105+ //conversion is required.
106+ //by this point astype is not null.
107+
108+ //flatten if required
109+ if ( array . Rank != 1 || array . GetType ( ) . GetElementType ( ) ? . IsArray == true ) //is multidim or jagged
110+ array = Arrays . Flatten ( array ) ;
111+
112+ try
113+ {
114+ return ToTensor (
115+ ArrayConvert . To ( array , astype . Value . as_numpy_typecode ( ) ) ,
116+ null
117+ ) ;
118+ } catch ( NotSupportedException )
119+ {
120+ //handle dtypes not supported by ArrayConvert
121+ var ret = Array . CreateInstance ( astype_type , array . LongLength ) ;
122+ Parallel . For ( 0 , ret . LongLength , i => ret . SetValue ( Convert . ChangeType ( array . GetValue ( i ) , astype_type ) , i ) ) ;
123+ return ToTensor ( ret , null ) ;
124+ }
125+ }
126+ }
127+
128+ /// <summary>
129+ /// Convert given <see cref="Array"/> to <see cref="Tensor"/>.
130+ /// </summary>
131+ /// <param name="constant">The constant scalar to convert</param>
132+ /// <param name="astype">Convert <paramref name="constant"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
133+ /// <exception cref="NotSupportedException"></exception>
134+ public static Tensor ToTensor < T > ( T constant , TF_DataType ? astype = null ) where T : unmanaged
135+ {
136+ //was conversion requested?
137+ if ( astype == null )
138+ {
139+ //No conversion required
140+ var constantType = typeof ( T ) . as_dtype ( ) ;
141+ if ( constantType == TF_DataType . TF_INT8 )
142+ return new Tensor ( ( sbyte ) ( object ) constant ) ;
143+
144+ if ( constantType == TF_DataType . TF_STRING )
145+ return new Tensor ( ( string ) ( object ) constant ) ;
146+
147+ #if _REGEN
148+ #region Compute
149+ switch ( InfoOf < T > . NPTypeCode )
150+ {
151+ % foreach supported_dtypes, supported_dtypes_lowercase%
152+ case NPTypeCode . #1 : return new Tensor ( ( #2 ) ( object ) constant ) ;
153+ %
154+ default :
155+ throw new NotSupportedException ( ) ;
156+ }
157+ #endregion
158+ #else
159+
160+ #region Compute
161+
162+ switch ( InfoOf < T > . NPTypeCode )
163+ {
164+ case NPTypeCode . Boolean : return new Tensor ( ( bool ) ( object ) constant ) ;
165+ case NPTypeCode . Byte : return new Tensor ( ( byte ) ( object ) constant ) ;
166+ case NPTypeCode . Int16 : return new Tensor ( ( short ) ( object ) constant ) ;
167+ case NPTypeCode . UInt16 : return new Tensor ( ( ushort ) ( object ) constant ) ;
168+ case NPTypeCode . Int32 : return new Tensor ( ( int ) ( object ) constant ) ;
169+ case NPTypeCode . UInt32 : return new Tensor ( ( uint ) ( object ) constant ) ;
170+ case NPTypeCode . Int64 : return new Tensor ( ( long ) ( object ) constant ) ;
171+ case NPTypeCode . UInt64 : return new Tensor ( ( ulong ) ( object ) constant ) ;
172+ case NPTypeCode . Char : return new Tensor ( Converts . ToByte ( constant ) ) ;
173+ case NPTypeCode . Double : return new Tensor ( ( double ) ( object ) constant ) ;
174+ case NPTypeCode . Single : return new Tensor ( ( float ) ( object ) constant ) ;
175+ default :
176+ throw new NotSupportedException ( ) ;
177+ }
178+
179+ #endregion
180+ #endif
181+ }
182+
183+ //conversion required
184+
185+ if ( astype == TF_DataType . TF_INT8 )
186+ return new Tensor ( Converts . ToSByte ( constant ) ) ;
187+
188+ if ( astype == TF_DataType . TF_STRING )
189+ return new Tensor ( Converts . ToString ( constant ) ) ;
190+
191+ var astype_np = astype ? . as_numpy_typecode ( ) ;
192+
193+ #if _REGEN
194+ #region Compute
195+ switch ( astype_np )
196+ {
197+ % foreach supported_dtypes, supported_dtypes_lowercase%
198+ case NPTypeCode . #1 : return new Tensor ( Converts . To #1 ( constant ) ) ;
199+ %
200+ default :
201+ throw new NotSupportedException ( ) ;
202+ }
203+ #endregion
204+ #else
205+
206+ #region Compute
207+ switch ( astype_np )
208+ {
209+ case NPTypeCode . Boolean : return new Tensor ( Converts . ToBoolean ( constant ) ) ;
210+ case NPTypeCode . Byte : return new Tensor ( Converts . ToByte ( constant ) ) ;
211+ case NPTypeCode . Int16 : return new Tensor ( Converts . ToInt16 ( constant ) ) ;
212+ case NPTypeCode . UInt16 : return new Tensor ( Converts . ToUInt16 ( constant ) ) ;
213+ case NPTypeCode . Int32 : return new Tensor ( Converts . ToInt32 ( constant ) ) ;
214+ case NPTypeCode . UInt32 : return new Tensor ( Converts . ToUInt32 ( constant ) ) ;
215+ case NPTypeCode . Int64 : return new Tensor ( Converts . ToInt64 ( constant ) ) ;
216+ case NPTypeCode . UInt64 : return new Tensor ( Converts . ToUInt64 ( constant ) ) ;
217+ case NPTypeCode . Char : return new Tensor ( Converts . ToByte ( constant ) ) ;
218+ case NPTypeCode . Double : return new Tensor ( Converts . ToDouble ( constant ) ) ;
219+ case NPTypeCode . Single : return new Tensor ( Converts . ToSingle ( constant ) ) ;
220+ default :
221+ throw new NotSupportedException ( ) ;
222+ }
223+ #endregion
224+ #endif
225+ }
226+
227+ /// <summary>
228+ /// Convert given <see cref="Array"/> to <see cref="Tensor"/>.
229+ /// </summary>
230+ /// <param name="constant">The constant scalar to convert</param>
231+ /// <param name="astype">Convert <paramref name="constant"/> to given <paramref name="astype"/> before inserting it into a <see cref="Tensor"/>.</param>
232+ /// <exception cref="NotSupportedException"></exception>
233+ public static Tensor ToTensor ( string constant , TF_DataType ? astype = null )
234+ {
235+ switch ( astype )
236+ {
237+ //was conversion requested?
238+ case null :
239+ case TF_DataType . TF_STRING :
240+ return new Tensor ( constant ) ;
241+ //conversion required
242+ case TF_DataType . TF_INT8 :
243+ return new Tensor ( Converts . ToSByte ( constant ) ) ;
244+ default :
245+ {
246+ var astype_np = astype ? . as_numpy_typecode ( ) ;
247+
248+ #if _REGEN
249+ #region Compute
250+ switch ( astype_np )
251+ {
252+ % foreach supported_dtypes, supported_dtypes_lowercase%
253+ case NPTypeCode . #1 : return new Tensor ( Converts . To #1 ( constant ) ) ;
254+ %
255+ default :
256+ throw new NotSupportedException ( ) ;
257+ }
258+ #endregion
259+ #else
260+
261+ #region Compute
262+ switch ( astype_np )
263+ {
264+ case NPTypeCode . Boolean : return new Tensor ( Converts . ToBoolean ( constant ) ) ;
265+ case NPTypeCode . Byte : return new Tensor ( Converts . ToByte ( constant ) ) ;
266+ case NPTypeCode . Int16 : return new Tensor ( Converts . ToInt16 ( constant ) ) ;
267+ case NPTypeCode . UInt16 : return new Tensor ( Converts . ToUInt16 ( constant ) ) ;
268+ case NPTypeCode . Int32 : return new Tensor ( Converts . ToInt32 ( constant ) ) ;
269+ case NPTypeCode . UInt32 : return new Tensor ( Converts . ToUInt32 ( constant ) ) ;
270+ case NPTypeCode . Int64 : return new Tensor ( Converts . ToInt64 ( constant ) ) ;
271+ case NPTypeCode . UInt64 : return new Tensor ( Converts . ToUInt64 ( constant ) ) ;
272+ case NPTypeCode . Char : return new Tensor ( Converts . ToByte ( constant ) ) ;
273+ case NPTypeCode . Double : return new Tensor ( Converts . ToDouble ( constant ) ) ;
274+ case NPTypeCode . Single : return new Tensor ( Converts . ToSingle ( constant ) ) ;
275+ default :
276+ throw new NotSupportedException ( ) ;
277+ }
278+ #endregion
279+ #endif
280+ }
281+ }
282+ }
283+
284+ }
285+ }
0 commit comments