Skip to content

Commit e56164d

Browse files
committed
BaseSession: revamped fetchValue (perf-op)
1 parent 621ffff commit e56164d

2 files changed

Lines changed: 252 additions & 149 deletions

File tree

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 157 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ limitations under the License.
2121
using System.Linq;
2222
using System.Numerics;
2323
using System.Text;
24+
using Google.Protobuf;
25+
using NumSharp.Backends;
26+
using Tensorflow.Util;
2427

2528
namespace Tensorflow
2629
{
@@ -246,111 +249,167 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] f
246249
return result;
247250
}
248251

249-
private unsafe NDArray fetchValue(IntPtr output)
252+
private static unsafe NDArray fetchValue(IntPtr output)
250253
{
251-
var tensor = new Tensor(output);
252-
NDArray nd = null;
253-
Type type = tensor.dtype.as_numpy_dtype();
254-
var ndims = tensor.shape;
255-
var offset = (byte*) c_api.TF_TensorData(output);
256-
257-
if(ndims.Length == 0)
254+
NDArray ret;
255+
using (var tensor = new Tensor(output))
258256
{
259-
switch (tensor.dtype)
257+
var ndims = tensor.shape;
258+
var srcAddress = c_api.TF_TensorData(output).ToInt64();
259+
260+
if (ndims.Length == 0)
260261
{
261-
case TF_DataType.TF_BOOL:
262-
nd = NDArray.Scalar(*(bool*)offset);
263-
break;
264-
case TF_DataType.TF_STRING:
265-
var bytes = tensor.BufferToArray();
266-
// wired, don't know why we have to start from offset 9.
267-
// length in the begin
268-
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
269-
nd = NDArray.FromString(str);
270-
break;
271-
case TF_DataType.TF_UINT8:
272-
nd = NDArray.Scalar(*(byte*)offset);
273-
break;
274-
case TF_DataType.TF_INT16:
275-
nd = NDArray.Scalar(*(short*)offset);
276-
break;
277-
case TF_DataType.TF_INT32:
278-
nd = NDArray.Scalar(*(int*)offset);
279-
break;
280-
case TF_DataType.TF_INT64:
281-
nd = NDArray.Scalar(*(long*)offset);
282-
break;
283-
case TF_DataType.TF_FLOAT:
284-
nd = NDArray.Scalar(*(float*)offset);
285-
break;
286-
case TF_DataType.TF_DOUBLE:
287-
nd = NDArray.Scalar(*(double*)offset);
288-
break;
289-
default:
290-
throw new NotImplementedException("can't fetch output");
291-
}
292-
}
293-
else
294-
{
295-
switch (tensor.dtype)
262+
switch (tensor.dtype)
263+
{
264+
case TF_DataType.TF_BOOL:
265+
ret = NDArray.Scalar(*(bool*) srcAddress);
266+
break;
267+
case TF_DataType.TF_STRING:
268+
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
269+
ret = NDArray.FromString(reader.ReadString());
270+
break;
271+
case TF_DataType.TF_UINT8:
272+
ret = NDArray.Scalar(*(byte*) srcAddress);
273+
break;
274+
case TF_DataType.TF_INT16:
275+
ret = NDArray.Scalar(*(short*) srcAddress);
276+
break;
277+
case TF_DataType.TF_INT32:
278+
ret = NDArray.Scalar(*(int*) srcAddress);
279+
break;
280+
case TF_DataType.TF_INT64:
281+
ret = NDArray.Scalar(*(long*) srcAddress);
282+
break;
283+
case TF_DataType.TF_UINT16:
284+
ret = NDArray.Scalar(*(ushort*) srcAddress);
285+
break;
286+
case TF_DataType.TF_UINT32:
287+
ret = NDArray.Scalar(*(uint*) srcAddress);
288+
break;
289+
case TF_DataType.TF_UINT64:
290+
ret = NDArray.Scalar(*(ulong*) srcAddress);
291+
break;
292+
case TF_DataType.TF_FLOAT:
293+
ret = NDArray.Scalar(*(float*) srcAddress);
294+
break;
295+
case TF_DataType.TF_DOUBLE:
296+
ret = NDArray.Scalar(*(double*) srcAddress);
297+
break;
298+
default:
299+
throw new NotImplementedException("can't fetch output");
300+
}
301+
} else
296302
{
297-
case TF_DataType.TF_BOOL:
298-
var bools = new bool[tensor.size];
299-
for (ulong i = 0; i < tensor.size; i++)
300-
bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i));
301-
nd = np.array(bools).reshape(ndims);
302-
break;
303-
case TF_DataType.TF_STRING:
304-
var bytes = tensor.BufferToArray();
305-
// wired, don't know why we have to start from offset 9.
306-
// length in the begin
307-
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
308-
nd = np.array(str);
309-
break;
310-
case TF_DataType.TF_UINT8:
311-
var _bytes = new byte[tensor.size];
312-
for (ulong i = 0; i < tensor.size; i++)
313-
_bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
314-
nd = np.array(_bytes).reshape(ndims);
315-
break;
316-
case TF_DataType.TF_INT16:
317-
var shorts = new short[tensor.size];
318-
for (ulong i = 0; i < tensor.size; i++)
319-
shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i));
320-
nd = np.array(shorts).reshape(ndims);
321-
break;
322-
case TF_DataType.TF_INT32:
323-
var ints = new int[tensor.size];
324-
for (ulong i = 0; i < tensor.size; i++)
325-
ints[i] = *(int*)(offset + (int)(tensor.itemsize * i));
326-
nd = np.array(ints).reshape(ndims);
327-
break;
328-
case TF_DataType.TF_INT64:
329-
var longs = new long[tensor.size];
330-
for (ulong i = 0; i < tensor.size; i++)
331-
longs[i] = *(long*)(offset + (int)(tensor.itemsize * i));
332-
nd = np.array(longs).reshape(ndims);
333-
break;
334-
case TF_DataType.TF_FLOAT:
335-
var floats = new float[tensor.size];
336-
for (ulong i = 0; i < tensor.size; i++)
337-
floats[i] = *(float*)(offset + (int)(tensor.itemsize * i));
338-
nd = np.array(floats).reshape(ndims);
339-
break;
340-
case TF_DataType.TF_DOUBLE:
341-
var doubles = new double[tensor.size];
342-
for (ulong i = 0; i < tensor.size; i++)
343-
doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i));
344-
nd = np.array(doubles).reshape(ndims);
345-
break;
346-
default:
347-
throw new NotImplementedException("can't fetch output");
303+
//var size = (long) tensor.size;
304+
//var itemsize = (long) tensor.itemsize;
305+
var bytesize = (long) tensor.bytesize;
306+
var src = (void*) srcAddress;
307+
308+
#if _REGEN
309+
#region Compute
310+
switch (tensor.dtype)
311+
{
312+
%foreach except(supported_dtypes, "Char"),except(supported_dtypes_lowercase, "char"),except(supported_dtypes_TF_DataType,"TF_STRING")%
313+
case TF_DataType.#3:
314+
{
315+
ret = new NDArray(NPTypeCode.#1, ndims, false);
316+
System.Buffer.MemoryCopy(src, #(#3=="TF_STRING"|"(byte*)ret.Unsafe.Address + 8"|"ret.Unsafe.Address"), bytesize, bytesize);
317+
break;
318+
}
319+
%
320+
case TF_DataType.TF_STRING:
321+
{
322+
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
323+
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
324+
ret = NDArray.FromString(reader.ReadString());
325+
break;
326+
}
327+
default:
328+
throw new NotSupportedException();
329+
}
330+
#endregion
331+
#else
332+
333+
#region Compute
334+
switch (tensor.dtype)
335+
{
336+
case TF_DataType.TF_BOOL:
337+
{
338+
ret = new NDArray(NPTypeCode.Boolean, ndims, false);
339+
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
340+
break;
341+
}
342+
case TF_DataType.TF_UINT8:
343+
{
344+
ret = new NDArray(NPTypeCode.Byte, ndims, false);
345+
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
346+
break;
347+
}
348+
case TF_DataType.TF_INT16:
349+
{
350+
ret = new NDArray(NPTypeCode.Int16, ndims, false);
351+
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
352+
break;
353+
}
354+
case TF_DataType.TF_UINT16:
355+
{
356+
ret = new NDArray(NPTypeCode.UInt16, ndims, false);
357+
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
358+
break;
359+
}
360+
case TF_DataType.TF_INT32:
361+
{
362+
ret = new NDArray(NPTypeCode.Int32, ndims, false);
363+
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
364+
break;
365+
}
366+
case TF_DataType.TF_UINT32:
367+
{
368+
ret = new NDArray(NPTypeCode.UInt32, ndims, false);
369+
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
370+
break;
371+
}
372+
case TF_DataType.TF_INT64:
373+
{
374+
ret = new NDArray(NPTypeCode.Int64, ndims, false);
375+
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
376+
break;
377+
}
378+
case TF_DataType.TF_UINT64:
379+
{
380+
ret = new NDArray(NPTypeCode.UInt64, ndims, false);
381+
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
382+
break;
383+
}
384+
case TF_DataType.TF_DOUBLE:
385+
{
386+
ret = new NDArray(NPTypeCode.Double, ndims, false);
387+
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
388+
break;
389+
}
390+
case TF_DataType.TF_FLOAT:
391+
{
392+
ret = new NDArray(NPTypeCode.Single, ndims, false);
393+
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
394+
break;
395+
}
396+
case TF_DataType.TF_STRING:
397+
{
398+
throw new NotImplementedException();
399+
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
400+
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
401+
ret = NDArray.FromString(reader.ReadString());
402+
break;
403+
}
404+
default:
405+
throw new NotSupportedException();
406+
}
407+
#endregion
408+
#endif
348409
}
349410
}
350-
351-
tensor.Dispose();
352411

353-
return nd;
412+
return ret;
354413
}
355414

356415
/// <summary>

0 commit comments

Comments
 (0)