forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTensor.String.cs
More file actions
84 lines (72 loc) · 2.88 KB
/
Tensor.String.cs
File metadata and controls
84 lines (72 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
using System;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using static Tensorflow.Binding;
namespace Tensorflow
{
/// <summary>
/// String-tensor helpers of <see cref="Tensor"/>: building a TF_STRING tensor from managed
/// strings or raw bytes, and reading the elements of a TF_STRING tensor back out.
/// </summary>
public partial class Tensor
{
    /// <summary>
    /// Number of bytes reserved per string element, used both to size the native allocation
    /// and as the stride when stepping from one element to the next.
    /// </summary>
    // NOTE(review): presumably sizeof(TF_TString) in the native library - confirm against the C headers.
    // NOTE(review): the identifier looks like a typo of "TF_TSTRING_SIZE"; it is kept unchanged
    // because other parts of this partial class may reference it.
    const int TF_TSRING_SIZE = 24;

    /// <summary>
    /// Creates a TF_STRING tensor whose elements are the UTF-8 encodings of <paramref name="strings"/>.
    /// </summary>
    /// <param name="strings">Element values, in the order they are written into the tensor.</param>
    /// <param name="shape">Shape of the tensor to allocate.</param>
    /// <returns>The native handle returned by <c>TF_AllocateTensor</c>.</returns>
    public IntPtr StringTensor(string[] strings, TensorShape shape)
    {
        // Encode every string as UTF-8 and delegate to the byte[][] overload.
        var buffer = new byte[strings.Length][];
        for (var i = 0; i < strings.Length; i++)
        {
            buffer[i] = Encoding.UTF8.GetBytes(strings[i]);
        }

        return StringTensor(buffer, shape);
    }

    /// <summary>
    /// Creates a TF_STRING tensor and copies each entry of <paramref name="buffer"/> into the
    /// corresponding element slot.
    /// </summary>
    /// <param name="buffer">Raw bytes of each element, in the order they are written into the tensor.</param>
    /// <param name="shape">Shape of the tensor to allocate; <c>shape.size</c> element slots are reserved.</param>
    /// <returns>The native handle returned by <c>TF_AllocateTensor</c>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="buffer"/> contains a null entry, or has more entries than the shape reserves slots for.
    /// </exception>
    public IntPtr StringTensor(byte[][] buffer, TensorShape shape)
    {
        // All validation happens before the native allocation so that a bad argument
        // can neither leak the tensor nor write outside of it.
        if (buffer == null)
            throw new ArgumentNullException(nameof(buffer));

        foreach (var element in buffer)
        {
            if (element == null)
                throw new ArgumentException("String tensor elements must not be null.", nameof(buffer));
        }

        // Only elementCount * TF_TSRING_SIZE bytes are allocated below; writing more
        // elements than that would run past the end of the native buffer.
        var elementCount = (ulong)shape.size;
        if ((ulong)buffer.Length > elementCount)
            throw new ArgumentException(
                $"buffer holds {buffer.Length} element(s) but the shape only has room for {elementCount}.",
                nameof(buffer));

        var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING,
            shape.ndim == 0 ? null : shape.dims.Select(x => (long)x).ToArray(),
            shape.ndim,
            elementCount * TF_TSRING_SIZE);

        var tstr = c_api.TF_TensorData(handle);
#if TRACK_TENSOR_LIFE
        print($"New TString 0x{handle.ToString("x16")} {AllocationType} Data: 0x{tstr.ToString("x16")}");
#endif
        // NOTE(review): slots beyond buffer.Length (when the shape reserves more) are not initialised here.
        for (int i = 0; i < buffer.Length; i++)
        {
            // Initialise the slot, copy the bytes in, then step to the next slot.
            c_api.TF_StringInit(tstr);
            c_api.TF_StringCopy(tstr, buffer[i], buffer[i].Length);
            tstr += TF_TSRING_SIZE;
        }

        return handle;
    }

    /// <summary>
    /// Reads every element of this TF_STRING tensor and decodes it as UTF-8.
    /// </summary>
    /// <returns>One string per element, in the order returned by <see cref="StringBytes"/>.</returns>
    /// <exception cref="InvalidOperationException">The tensor's dtype is not TF_STRING.</exception>
    public string[] StringData()
    {
        var buffer = StringBytes();

        var _str = new string[buffer.Length];
        for (int i = 0; i < _str.Length; i++)
        {
            _str[i] = Encoding.UTF8.GetString(buffer[i]);
        }

        return _str;
    }

    /// <summary>
    /// Reads the raw bytes of every element of this TF_STRING tensor.
    /// </summary>
    /// <returns>One byte array per element; the element count is the product of the shape's dimensions.</returns>
    /// <exception cref="InvalidOperationException">The tensor's dtype is not TF_STRING.</exception>
    public byte[][] StringBytes()
    {
        if (dtype != TF_DataType.TF_STRING)
            throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})");

        // Element count is the product of all dimensions (1 when there are no dimensions).
        int size = 1;
        foreach (var s in TensorShape.dims)
            size *= s;

        var buffer = new byte[size][];

        // The tensor data is walked as consecutive slots of TF_TSRING_SIZE bytes each;
        // the bytes and length of every slot are obtained through the TF_String* accessors.
        var tstrings = TensorDataPointer;
        for (int i = 0; i < buffer.Length; i++)
        {
            var data = c_api.TF_StringGetDataPointer(tstrings);
            var len = c_api.TF_StringGetSize(tstrings);

            // Copy the native bytes into a managed array of exactly the reported length.
            buffer[i] = new byte[len];
            Marshal.Copy(data, buffer[i], 0, Convert.ToInt32(len));

            tstrings += TF_TSRING_SIZE;
        }

        return buffer;
    }
}
}