Skip to content

Commit be996fa

Browse files
committed
Tensor.Creation: Revamp of CreateTensorFromArray to properly handle TF_NewTensor
- Added other missing AllocationType setting in different cases.
1 parent a539b9f commit be996fa

1 file changed

Lines changed: 27 additions & 5 deletions

File tree

src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -440,14 +440,15 @@ public unsafe Tensor(Complex value, TF_DataType? dType = null)
440440
#endif
441441

442442
/// <summary>
443-
/// Create a string Tensor from the given string
443+
/// Create a string Tensor from the given string
444444
/// </summary>
445445
public unsafe Tensor(string str)
446446
{
447447
var status = new Status();
448448
var buffer = Encoding.UTF8.GetBytes(str);
449449
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
450450
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
451+
AllocationType = AllocationType.Tensorflow;
451452

452453
IntPtr tensor = c_api.TF_TensorData(handle);
453454
Marshal.WriteInt64(tensor, 0);
@@ -459,6 +460,9 @@ public unsafe Tensor(string str)
459460

460461
public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
461462
{
463+
if (tensorDType == null)
464+
tensorDType = nd.dtype.as_dtype();
465+
462466
// todo: handle nd of type "String" here too
463467
if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte)
464468
{
@@ -467,6 +471,7 @@ public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
467471
var bytesLength = (UIntPtr) nd.size;
468472
var size = c_api.TF_StringEncodedSize(bytesLength);
469473
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
474+
AllocationType = AllocationType.Tensorflow;
470475

471476
IntPtr tensor = c_api.TF_TensorData(handle);
472477
Marshal.WriteInt64(tensor, 0);
@@ -481,6 +486,7 @@ public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
481486
var buffer = nd.ToArray<byte>();
482487
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
483488
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
489+
AllocationType = AllocationType.Tensorflow;
484490

485491
IntPtr tensor = c_api.TF_TensorData(handle);
486492
Marshal.WriteInt64(tensor, 0);
@@ -535,6 +541,7 @@ public unsafe Tensor(byte[][] buffer, long[] shape)
535541
int totalSize = size + buffer.Length * 8;
536542
ulong offset = 0;
537543
IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr) totalSize);
544+
AllocationType = AllocationType.Tensorflow;
538545

539546
// Clear offset table
540547
IntPtr pOffset = TF_TensorData(handle);
@@ -626,12 +633,27 @@ protected IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data,
626633

627634
// get a handle to the pinned array which we will pass on to the tensor computation engine to use
628635
var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned);
629-
AllocationType = AllocationType.GCHandle;
630-
AllocationHandle = gcHandle;
636+
var pinnedAddr = gcHandle.AddrOfPinnedObject();
631637

638+
//call NewTensor
639+
IntPtr handle;
632640
if (shape == null || shape.Length == 0)
633-
return TF_NewTensor(dt, new long[0], 0, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr) (count * element_size));
634-
return TF_NewTensor(dt, shape, shape.Length, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr) (count * element_size));
641+
handle = TF_NewTensor(dt, new long[0], 0, pinnedAddr + start * element_size, (UIntPtr) (count * element_size));
642+
else
643+
handle = TF_NewTensor(dt, shape, shape.Length, pinnedAddr + start * element_size, (UIntPtr) (count * element_size));
644+
645+
//Figure if TF decided to clone or not.
646+
if (c_api.TF_TensorData(handle) == pinnedAddr)
647+
{
648+
AllocationType = AllocationType.GCHandle;
649+
AllocationHandle = gcHandle;
650+
} else
651+
{
652+
AllocationType = AllocationType.Tensorflow;
653+
gcHandle.Free();
654+
}
655+
656+
return handle;
635657
}
636658
}
637659
}

0 commit comments

Comments
 (0)