forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTensor.Value.cs
More file actions
75 lines (65 loc) · 2.17 KB
/
Tensor.Value.cs
File metadata and controls
75 lines (65 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
using Tensorflow.NumPy;
using System;
using System.Text;
using static Tensorflow.Binding;
namespace Tensorflow
{
    public partial class Tensor
    {
        /// <summary>
        /// Copies the elements of this tensor into a newly allocated managed array.
        /// No element conversion is performed: <typeparamref name="T"/> must map
        /// (via <c>as_tf_dtype</c>) to exactly this tensor's <c>dtype</c>.
        /// </summary>
        /// <typeparam name="T">Unmanaged element type corresponding to the tensor's dtype.</typeparam>
        /// <returns>A new array of length <c>size</c> holding a copy of the tensor's data.</returns>
        /// <exception cref="ArrayTypeMismatchException">
        /// <typeparamref name="T"/> does not map to this tensor's dtype.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// The tensor's byte length (<c>size * dtypesize</c>) exceeds the byte capacity of the
        /// destination array (raised by <see cref="System.Buffer.MemoryCopy(void*, void*, long, long)"/>).
        /// </exception>
        public virtual unsafe T[] ToArray<T>() where T : unmanaged
        {
            // Are the types matching?
            if (typeof(T).as_tf_dtype() != dtype)
            {
                throw new ArrayTypeMismatchException($"Required dtype {dtype} mismatch with {typeof(T).as_tf_dtype()}.");
            }

            // Scalar fast path: read the single element directly from the tensor's buffer.
            if (ndim == 0 && size == 1)
            {
                return new T[] { *(T*)buffer };
            }

            // Types match, no need to perform a cast: copy the raw bytes.
            var ret = new T[size];
            var len = (long)(size * dtypesize);

            // Pass the real byte capacity of `ret` as the destination size (not `len`),
            // so Buffer.MemoryCopy throws instead of writing past the managed array
            // if the tensor's byte count were ever larger than sizeof(T) * size.
            var capacity = ret.LongLength * sizeof(T);
            var src = (T*)buffer;
            fixed (T* dst = ret)
            {
                System.Buffer.MemoryCopy(src, dst, capacity, len);
            }
            return ret;
        }

        /// <summary>
        /// Copy of the contents of this Tensor into a NumPy array or scalar.
        /// </summary>
        /// <returns>
        /// A NumPy array of the same shape and dtype or a NumPy scalar, if this
        /// Tensor has rank 0.
        /// </returns>
        public NDArray numpy()
            => GetNDArray(dtype);

        /// <summary>
        /// Builds an <see cref="NDArray"/> from this tensor according to <paramref name="dtype"/>.
        /// </summary>
        /// <param name="dtype">
        /// Data type that selects the conversion path (note: shadows the <c>dtype</c> member).
        /// </param>
        /// <returns>
        /// For <c>TF_STRING</c>, an NDArray built from <c>StringData()</c> with this tensor's shape;
        /// otherwise an NDArray constructed from this tensor with <c>clone: true</c>
        /// (presumably a deep copy — confirm against the NDArray constructor).
        /// </returns>
        protected NDArray GetNDArray(TF_DataType dtype)
        {
            if (dtype == TF_DataType.TF_STRING)
            {
                var str = StringData();
                return new NDArray(str, shape);
            }

            return new NDArray(this, clone: true);
        }

        /// <summary>
        /// Copies the memory of the current buffer onto a newly allocated array.
        /// </summary>
        /// <returns>A new byte array of length <c>bytesize</c> containing the raw tensor bytes.</returns>
        public unsafe byte[] BufferToArray()
        {
            // ReSharper disable once LocalVariableHidesMember
            var data = new byte[bytesize];
            // `data` is exactly `bytesize` long, so the destination size passed here is accurate.
            fixed (byte* dst = data)
            {
                System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize);
            }
            return data;
        }
    }
}