Skip to content

Commit 8ae2feb

Browse files
committed
Malformed TF_STRING tensor; element 0 out of range
1 parent 5720dfd commit 8ae2feb

8 files changed

Lines changed: 131 additions & 94 deletions

File tree

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tenso
7676
obj = temp_obj;
7777

7878
// If obj appears to be a name...
79-
if (obj is String str)
79+
if (obj is string name)
8080
{
81-
if(str.Contains(":") && allow_tensor)
81+
if(name.Contains(":") && allow_tensor)
8282
{
83-
string op_name = str.Split(':')[0];
84-
int out_n = int.Parse(str.Split(':')[1]);
83+
string op_name = name.Split(':')[0];
84+
int out_n = int.Parse(name.Split(':')[1]);
8585

8686
if (_nodes_by_name.ContainsKey(op_name))
8787
return _nodes_by_name[op_name].outputs[out_n];

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ private NDArray _run(object fetches, FeedItem[] feed_dict = null)
6767
default:
6868
throw new NotImplementedException("_run subfeed");
6969
}
70-
feed_map[subfeed_t.name] = new Tuple<object, object>(subfeed_t, subfeed.Value);
70+
feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
7171
}
7272
}
7373

@@ -178,7 +178,8 @@ private unsafe NDArray fetchValue(IntPtr output)
178178
case TF_DataType.TF_STRING:
179179
var bytes = tensor.Data();
180180
// wired, don't know why we have to start from offset 9.
181-
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes.Length - 9);
181+
// length in the begin
182+
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
182183
nd = np.array(str).reshape();
183184
break;
184185
case TF_DataType.TF_INT16:
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
using NumSharp.Core;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Runtime.InteropServices;
6+
using System.Text;
7+
using static Tensorflow.c_api;
8+
9+
namespace Tensorflow
10+
{
11+
public partial class Tensor
12+
{
13+
/// <summary>
14+
/// if original buffer is free.
15+
/// </summary>
16+
private bool deallocator_called;
17+
18+
public Tensor(IntPtr handle)
19+
{
20+
_handle = handle;
21+
}
22+
23+
public Tensor(NDArray nd)
24+
{
25+
_handle = Allocate(nd);
26+
}
27+
28+
private IntPtr Allocate(NDArray nd)
29+
{
30+
IntPtr dotHandle = IntPtr.Zero;
31+
ulong size = 0;
32+
33+
if (nd.dtype.Name != "String")
34+
{
35+
dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size);
36+
size = (ulong)(nd.size * nd.dtypesize);
37+
}
38+
39+
switch (nd.dtype.Name)
40+
{
41+
case "Int16":
42+
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size);
43+
break;
44+
case "Int32":
45+
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size);
46+
break;
47+
case "Single":
48+
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size);
49+
break;
50+
case "Double":
51+
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
52+
break;
53+
case "String":
54+
/*var value = nd.Data<string>()[0];
55+
var bytes = Encoding.UTF8.GetBytes(value);
56+
dotHandle = Marshal.AllocHGlobal(bytes.Length + 1);
57+
Marshal.Copy(bytes, 0, dotHandle, bytes.Length);
58+
size = (ulong)bytes.Length;*/
59+
60+
var str = nd.Data<string>()[0];
61+
ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length);
62+
//dotHandle = Marshal.AllocHGlobal((int)dst_len);
63+
//size = c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status);
64+
65+
var dataType1 = ToTFDataType(nd.dtype);
66+
// shape
67+
var dims1 = nd.shape.Select(x => (long)x).ToArray();
68+
69+
var tfHandle1 = c_api.TF_AllocateTensor(dataType1,
70+
dims1,
71+
nd.ndim,
72+
dst_len);
73+
74+
dotHandle = c_api.TF_TensorData(tfHandle1);
75+
c_api.TF_StringEncode(str, (ulong)str.Length, dotHandle, dst_len, status);
76+
return tfHandle1;
77+
break;
78+
default:
79+
throw new NotImplementedException("Marshal.Copy failed.");
80+
}
81+
82+
var dataType = ToTFDataType(nd.dtype);
83+
// shape
84+
var dims = nd.shape.Select(x => (long)x).ToArray();
85+
// Free the original buffer and set flag
86+
Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) =>
87+
{
88+
Marshal.FreeHGlobal(dotHandle);
89+
closure = true;
90+
};
91+
92+
var tfHandle = c_api.TF_NewTensor(dataType,
93+
dims,
94+
nd.ndim,
95+
dotHandle,
96+
size,
97+
deallocator,
98+
ref deallocator_called);
99+
100+
return tfHandle;
101+
}
102+
103+
public Tensor(Operation op, int value_index, TF_DataType dtype)
104+
{
105+
this.op = op;
106+
this.value_index = value_index;
107+
this._dtype = dtype;
108+
_id = ops.uid();
109+
}
110+
}
111+
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -95,86 +95,6 @@ public int rank
9595

9696
public int NDims => rank;
9797

98-
/// <summary>
99-
/// if original buffer is free.
100-
/// </summary>
101-
private bool deallocator_called;
102-
103-
public Tensor(IntPtr handle)
104-
{
105-
_handle = handle;
106-
}
107-
108-
public Tensor(NDArray nd)
109-
{
110-
_handle = Allocate(nd);
111-
}
112-
113-
private IntPtr Allocate(NDArray nd)
114-
{
115-
IntPtr dotHandle = IntPtr.Zero;
116-
ulong size = 0;
117-
118-
if (nd.dtype.Name != "String")
119-
{
120-
dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size);
121-
size = (ulong)(nd.size * nd.dtypesize);
122-
}
123-
124-
switch (nd.dtype.Name)
125-
{
126-
case "Int16":
127-
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size);
128-
break;
129-
case "Int32":
130-
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size);
131-
break;
132-
case "Single":
133-
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size);
134-
break;
135-
case "Double":
136-
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
137-
break;
138-
case "String":
139-
var value = nd.Data<string>()[0];
140-
var bytes = Encoding.UTF8.GetBytes(value);
141-
dotHandle = Marshal.AllocHGlobal(bytes.Length + 1);
142-
Marshal.Copy(bytes, 0, dotHandle, bytes.Length);
143-
size = (ulong)bytes.Length;
144-
break;
145-
default:
146-
throw new NotImplementedException("Marshal.Copy failed.");
147-
}
148-
149-
var dataType = ToTFDataType(nd.dtype);
150-
// shape
151-
var dims = nd.shape.Select(x => (long)x).ToArray();
152-
// Free the original buffer and set flag
153-
Deallocator deallocator = (IntPtr values, IntPtr len, ref bool closure) =>
154-
{
155-
Marshal.FreeHGlobal(dotHandle);
156-
closure = true;
157-
};
158-
159-
var tfHandle = c_api.TF_NewTensor(dataType,
160-
dims,
161-
nd.ndim,
162-
dotHandle,
163-
size,
164-
deallocator,
165-
ref deallocator_called);
166-
167-
return tfHandle;
168-
}
169-
170-
public Tensor(Operation op, int value_index, TF_DataType dtype)
171-
{
172-
this.op = op;
173-
this.value_index = value_index;
174-
this._dtype = dtype;
175-
_id = ops.uid();
176-
}
177-
17898
public Operation[] Consumers => consumers();
17999

180100
public string Device => op.Device;

src/TensorFlowNET.Core/Tensors/c_api.tensor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ public partial class c_api
120120
/// <param name="status">TF_Status*</param>
121121
/// <returns>On success returns the size in bytes of the encoded string.</returns>
122122
[DllImport(TensorFlowLibName)]
123-
public static extern ulong TF_StringEncode(string src, ulong src_len, string dst, ulong dst_len, IntPtr status);
123+
public static extern ulong TF_StringEncode(string src, ulong src_len, IntPtr dst, ulong dst_len, IntPtr status);
124124

125125
/// <summary>
126126
/// Decode a string encoded using TF_StringEncode.
@@ -132,6 +132,6 @@ public partial class c_api
132132
/// <param name="status">TF_Status*</param>
133133
/// <returns></returns>
134134
[DllImport(TensorFlowLibName)]
135-
public static extern ulong TF_StringDecode(string src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status);
135+
public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status);
136136
}
137137
}

src/TensorFlowNET.Core/Train/Saving/Saver.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,14 @@ private void _check_saver_def()
138138
public string save(Session sess,
139139
string save_path,
140140
string global_step = "",
141+
string latest_filename = "",
141142
string meta_graph_suffix = "meta",
142143
bool write_meta_graph = true,
143144
bool write_state = true,
144145
bool strip_default_attrs = false)
145146
{
146-
string latest_filename = "checkpoint";
147+
if (string.IsNullOrEmpty(latest_filename))
148+
latest_filename = "checkpoint";
147149
string model_checkpoint_path = "";
148150
string checkpoint_file = "";
149151

test/TensorFlowNET.UnitTest/ConstantTest.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System;
44
using System.Collections.Generic;
55
using System.Linq;
6+
using System.Runtime.InteropServices;
67
using System.Text;
78
using Tensorflow;
89

@@ -104,11 +105,14 @@ public void StringEncode()
104105
string str = "Hello, TensorFlow.NET!";
105106
ulong dst_len = c_api.TF_StringEncodedSize((ulong)str.Length);
106107
Assert.AreEqual(dst_len, (ulong)23);
107-
string dst = "";
108-
c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status);
108+
IntPtr dst = Marshal.AllocHGlobal((int)dst_len);
109+
ulong encoded_len = c_api.TF_StringEncode(str, (ulong)str.Length, dst, dst_len, status);
110+
Assert.AreEqual((ulong)23, encoded_len);
109111
Assert.AreEqual(status.Code, TF_Code.TF_OK);
110-
111-
//c_api.TF_StringDecode(str, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
112+
string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte));
113+
Assert.AreEqual(encoded_str, str);
114+
Assert.AreEqual(str.Length, Marshal.ReadByte(dst));
115+
//c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
112116
}
113117

114118
/// <summary>

test/TensorFlowNET.UnitTest/TrainSaverTest.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ public void Save1()
4545
});
4646
}
4747

48-
[TestMethod]
4948
public void Save2()
5049
{
5150
var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);

0 commit comments

Comments
 (0)