@@ -21,6 +21,9 @@ limitations under the License.
2121using System . Linq ;
2222using System . Numerics ;
2323using System . Text ;
24+ using Google . Protobuf ;
25+ using NumSharp . Backends ;
26+ using Tensorflow . Util ;
2427
2528namespace Tensorflow
2629{
@@ -246,111 +249,167 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] f
246249 return result ;
247250 }
248251
249- private unsafe NDArray fetchValue ( IntPtr output )
252+ private static unsafe NDArray fetchValue ( IntPtr output )
250253 {
251- var tensor = new Tensor ( output ) ;
252- NDArray nd = null ;
253- Type type = tensor . dtype . as_numpy_dtype ( ) ;
254- var ndims = tensor . shape ;
255- var offset = ( byte * ) c_api . TF_TensorData ( output ) ;
256-
257- if ( ndims . Length == 0 )
254+ NDArray ret ;
255+ using ( var tensor = new Tensor ( output ) )
258256 {
259- switch ( tensor . dtype )
257+ var ndims = tensor . shape ;
258+ var srcAddress = c_api . TF_TensorData ( output ) . ToInt64 ( ) ;
259+
260+ if ( ndims . Length == 0 )
260261 {
261- case TF_DataType . TF_BOOL :
262- nd = NDArray . Scalar ( * ( bool * ) offset ) ;
263- break ;
264- case TF_DataType . TF_STRING :
265- var bytes = tensor . BufferToArray ( ) ;
266- // wired, don't know why we have to start from offset 9.
267- // length in the begin
268- var str = UTF8Encoding . Default . GetString ( bytes , 9 , bytes [ 8 ] ) ;
269- nd = NDArray . FromString ( str ) ;
270- break ;
271- case TF_DataType . TF_UINT8 :
272- nd = NDArray . Scalar ( * ( byte * ) offset ) ;
273- break ;
274- case TF_DataType . TF_INT16 :
275- nd = NDArray . Scalar ( * ( short * ) offset ) ;
276- break ;
277- case TF_DataType . TF_INT32 :
278- nd = NDArray . Scalar ( * ( int * ) offset ) ;
279- break ;
280- case TF_DataType . TF_INT64 :
281- nd = NDArray . Scalar ( * ( long * ) offset ) ;
282- break ;
283- case TF_DataType . TF_FLOAT :
284- nd = NDArray . Scalar ( * ( float * ) offset ) ;
285- break ;
286- case TF_DataType . TF_DOUBLE :
287- nd = NDArray . Scalar ( * ( double * ) offset ) ;
288- break ;
289- default :
290- throw new NotImplementedException ( "can't fetch output" ) ;
291- }
292- }
293- else
294- {
295- switch ( tensor . dtype )
262+ switch ( tensor . dtype )
263+ {
264+ case TF_DataType . TF_BOOL :
265+ ret = NDArray . Scalar ( * ( bool * ) srcAddress ) ;
266+ break ;
267+ case TF_DataType . TF_STRING :
268+ using ( var reader = new CodedInputStream ( new IntPtr ( srcAddress ) . Stream ( 8 , ( long ) tensor . bytesize ) ) )
269+ ret = NDArray . FromString ( reader . ReadString ( ) ) ;
270+ break ;
271+ case TF_DataType . TF_UINT8 :
272+ ret = NDArray . Scalar ( * ( byte * ) srcAddress ) ;
273+ break ;
274+ case TF_DataType . TF_INT16 :
275+ ret = NDArray . Scalar ( * ( short * ) srcAddress ) ;
276+ break ;
277+ case TF_DataType . TF_INT32 :
278+ ret = NDArray . Scalar ( * ( int * ) srcAddress ) ;
279+ break ;
280+ case TF_DataType . TF_INT64 :
281+ ret = NDArray . Scalar ( * ( long * ) srcAddress ) ;
282+ break ;
283+ case TF_DataType . TF_UINT16 :
284+ ret = NDArray . Scalar ( * ( ushort * ) srcAddress ) ;
285+ break ;
286+ case TF_DataType . TF_UINT32 :
287+ ret = NDArray . Scalar ( * ( uint * ) srcAddress ) ;
288+ break ;
289+ case TF_DataType . TF_UINT64 :
290+ ret = NDArray . Scalar ( * ( ulong * ) srcAddress ) ;
291+ break ;
292+ case TF_DataType . TF_FLOAT :
293+ ret = NDArray . Scalar ( * ( float * ) srcAddress ) ;
294+ break ;
295+ case TF_DataType . TF_DOUBLE :
296+ ret = NDArray . Scalar ( * ( double * ) srcAddress ) ;
297+ break ;
298+ default :
299+ throw new NotImplementedException ( "can't fetch output" ) ;
300+ }
301+ } else
296302 {
297- case TF_DataType . TF_BOOL :
298- var bools = new bool [ tensor . size ] ;
299- for ( ulong i = 0 ; i < tensor . size ; i ++ )
300- bools [ i ] = * ( bool * ) ( offset + ( int ) ( tensor . itemsize * i ) ) ;
301- nd = np . array ( bools ) . reshape ( ndims ) ;
302- break ;
303- case TF_DataType . TF_STRING :
304- var bytes = tensor . BufferToArray ( ) ;
305- // wired, don't know why we have to start from offset 9.
306- // length in the begin
307- var str = UTF8Encoding . Default . GetString ( bytes , 9 , bytes [ 8 ] ) ;
308- nd = np . array ( str ) ;
309- break ;
310- case TF_DataType . TF_UINT8 :
311- var _bytes = new byte [ tensor . size ] ;
312- for ( ulong i = 0 ; i < tensor . size ; i ++ )
313- _bytes [ i ] = * ( byte * ) ( offset + ( int ) ( tensor . itemsize * i ) ) ;
314- nd = np . array ( _bytes ) . reshape ( ndims ) ;
315- break ;
316- case TF_DataType . TF_INT16 :
317- var shorts = new short [ tensor . size ] ;
318- for ( ulong i = 0 ; i < tensor . size ; i ++ )
319- shorts [ i ] = * ( short * ) ( offset + ( int ) ( tensor . itemsize * i ) ) ;
320- nd = np . array ( shorts ) . reshape ( ndims ) ;
321- break ;
322- case TF_DataType . TF_INT32 :
323- var ints = new int [ tensor . size ] ;
324- for ( ulong i = 0 ; i < tensor . size ; i ++ )
325- ints [ i ] = * ( int * ) ( offset + ( int ) ( tensor . itemsize * i ) ) ;
326- nd = np . array ( ints ) . reshape ( ndims ) ;
327- break ;
328- case TF_DataType . TF_INT64 :
329- var longs = new long [ tensor . size ] ;
330- for ( ulong i = 0 ; i < tensor . size ; i ++ )
331- longs [ i ] = * ( long * ) ( offset + ( int ) ( tensor . itemsize * i ) ) ;
332- nd = np . array ( longs ) . reshape ( ndims ) ;
333- break ;
334- case TF_DataType . TF_FLOAT :
335- var floats = new float [ tensor . size ] ;
336- for ( ulong i = 0 ; i < tensor . size ; i ++ )
337- floats [ i ] = * ( float * ) ( offset + ( int ) ( tensor . itemsize * i ) ) ;
338- nd = np . array ( floats ) . reshape ( ndims ) ;
339- break ;
340- case TF_DataType . TF_DOUBLE :
341- var doubles = new double [ tensor . size ] ;
342- for ( ulong i = 0 ; i < tensor . size ; i ++ )
343- doubles [ i ] = * ( double * ) ( offset + ( int ) ( tensor . itemsize * i ) ) ;
344- nd = np . array ( doubles ) . reshape ( ndims ) ;
345- break ;
346- default :
347- throw new NotImplementedException ( "can't fetch output" ) ;
303+ //var size = (long) tensor.size;
304+ //var itemsize = (long) tensor.itemsize;
305+ var bytesize = ( long ) tensor . bytesize ;
306+ var src = ( void * ) srcAddress ;
307+
308+ #if _REGEN
309+ #region Compute
310+ switch ( tensor . dtype )
311+ {
312+ % foreach except( supported_dtypes , "Char" ) , except ( supported_dtypes_lowercase , "char" ) , except ( supported_dtypes_TF_DataType , "TF_STRING" ) %
313+ case TF_DataType. #3 :
314+ {
315+ ret = new NDArray ( NPTypeCode . #1 , ndims , false ) ;
316+ System. Buffer . MemoryCopy ( src , #( #3 == "TF_STRING" | "(byte*)ret.Unsafe.Address + 8" | "ret.Unsafe.Address" ) , bytesize , bytesize ) ;
317+ break ;
318+ }
319+ %
320+ case TF_DataType. TF_STRING:
321+ {
322+ //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
323+ using ( var reader = new CodedInputStream ( new IntPtr ( srcAddress ) . Stream ( 8 , ( long ) tensor . bytesize ) ) )
324+ ret = NDArray. FromString( reader. ReadString( ) ) ;
325+ break ;
326+ }
327+ default :
328+ throw new NotSupportedException ( ) ;
329+ }
330+ #endregion
331+ #else
332+
333+ #region Compute
334+ switch ( tensor . dtype )
335+ {
336+ case TF_DataType. TF_BOOL :
337+ {
338+ ret = new NDArray ( NPTypeCode . Boolean , ndims , false ) ;
339+ System. Buffer . MemoryCopy ( src , ret . Unsafe . Address , bytesize , bytesize ) ;
340+ break ;
341+ }
342+ case TF_DataType. TF_UINT8 :
343+ {
344+ ret = new NDArray ( NPTypeCode . Byte , ndims , false ) ;
345+ System. Buffer . MemoryCopy ( src , ret . Unsafe . Address , bytesize , bytesize ) ;
346+ break ;
347+ }
348+ case TF_DataType. TF_INT16 :
349+ {
350+ ret = new NDArray ( NPTypeCode . Int16 , ndims , false ) ;
351+ System. Buffer . MemoryCopy ( src , ret . Unsafe . Address , bytesize , bytesize ) ;
352+ break ;
353+ }
354+ case TF_DataType. TF_UINT16 :
355+ {
356+ ret = new NDArray ( NPTypeCode . UInt16 , ndims , false ) ;
357+ System. Buffer . MemoryCopy ( src , ret . Unsafe . Address , bytesize , bytesize ) ;
358+ break ;
359+ }
360+ case TF_DataType. TF_INT32 :
361+ {
362+ ret = new NDArray ( NPTypeCode . Int32 , ndims , false ) ;
363+ System. Buffer . MemoryCopy ( src , ret . Unsafe . Address , bytesize , bytesize ) ;
364+ break ;
365+ }
366+ case TF_DataType . TF_UINT32 :
367+ {
368+ ret = new NDArray ( NPTypeCode . UInt32 , ndims , false ) ;
369+ System. Buffer . MemoryCopy ( src , ret . Unsafe . Address , bytesize , bytesize ) ;
370+ break ;
371+ }
372+ case TF_DataType . TF_INT64 :
373+ {
374+ ret = new NDArray ( NPTypeCode . Int64 , ndims , false ) ;
375+ System. Buffer . MemoryCopy ( src , ret . Unsafe . Address , bytesize , bytesize ) ;
376+ break ;
377+ }
378+ case TF_DataType . TF_UINT64 :
379+ {
380+ ret = new NDArray ( NPTypeCode . UInt64 , ndims , false ) ;
381+ System. Buffer . MemoryCopy ( src , ret . Unsafe . Address , bytesize , bytesize ) ;
382+ break ;
383+ }
384+ case TF_DataType . TF_DOUBLE :
385+ {
386+ ret = new NDArray ( NPTypeCode . Double , ndims , false ) ;
387+ System. Buffer . MemoryCopy ( src , ret . Unsafe . Address , bytesize , bytesize ) ;
388+ break ;
389+ }
390+ case TF_DataType . TF_FLOAT :
391+ {
392+ ret = new NDArray ( NPTypeCode . Single , ndims , false ) ;
393+ System. Buffer . MemoryCopy ( src , ret . Unsafe . Address , bytesize , bytesize ) ;
394+ break ;
395+ }
396+ case TF_DataType . TF_STRING :
397+ {
398+ throw new NotImplementedException ( ) ;
399+ //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
400+ using ( var reader = new CodedInputStream ( new IntPtr ( srcAddress ) . Stream ( 8 , ( long ) tensor . bytesize ) ) )
401+ ret = NDArray. FromString ( reader . ReadString ( ) ) ;
402+ break ;
403+ }
404+ default:
405+ throw new NotSupportedException( ) ;
406+ }
407+ #endregion
408+ #endif
348409 }
349410 }
350-
351- tensor . Dispose ( ) ;
352411
353- return nd ;
412+ return ret ;
354413 }
355414
356415 /// <summary>
0 commit comments