Skip to content

Commit 4ca080e

Browse files
committed
Hello World works.
1 parent 9ed2dd5 commit 4ca080e

6 files changed

Lines changed: 18 additions & 9 deletions

File tree

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ private unsafe NDArray fetchValue(IntPtr output)
302302
// wired, don't know why we have to start from offset 9.
303303
// length in the begin
304304
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
305-
nd = np.array(str).reshape();
305+
nd = np.array(str);
306306
break;
307307
case TF_DataType.TF_UINT8:
308308
var _bytes = new byte[tensor.size];

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ Docs: https://tensorflownet.readthedocs.io</Description>
6363

6464
<ItemGroup>
6565
<PackageReference Include="Google.Protobuf" Version="3.9.0" />
66-
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="4.5.2" />
6766
</ItemGroup>
6867

6968
<ItemGroup>

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,13 @@ public static Type as_numpy_datatype(this TF_DataType type)
5151
}
5252

5353
// "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"
54-
public static TF_DataType as_dtype(Type type)
54+
public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null)
5555
{
56-
TF_DataType dtype = TF_DataType.DtInvalid;
57-
5856
switch (type.Name)
5957
{
58+
case "Char":
59+
dtype = dtype ?? TF_DataType.TF_UINT8;
60+
break;
6061
case "SByte":
6162
dtype = TF_DataType.TF_INT8;
6263
break;
@@ -100,7 +101,7 @@ public static TF_DataType as_dtype(Type type)
100101
throw new Exception("as_dtype Not Implemented");
101102
}
102103

103-
return dtype;
104+
return dtype.Value;
104105
}
105106

106107
public static DataType as_datatype_enum(this TF_DataType type)

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
226226
}
227227
}
228228

229-
var numpy_dtype = dtypes.as_dtype(nparray.dtype);
229+
var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype);
230230
if (numpy_dtype == TF_DataType.DtInvalid)
231231
throw new TypeError($"Unrecognized data type: {nparray.dtype}");
232232

src/TensorFlowNET.Core/Tensors/tf.constant.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ public static Tensor constant(object value,
3333
verify_shape: verify_shape,
3434
allow_broadcast: false);
3535

36+
public static Tensor constant(string value,
37+
string name = "Const") => constant_op._constant_impl(value,
38+
tf.@string,
39+
new int[] { 1 },
40+
name,
41+
verify_shape: false,
42+
allow_broadcast: false);
43+
3644
public static Tensor constant(float value,
3745
int shape,
3846
string name = "Const") => constant_op._constant_impl(value,

test/TensorFlowNET.Examples/HelloWorld.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ of the Constant op. */
2929
{
3030
// Run the op
3131
var result = sess.run(hello);
32-
Console.WriteLine(result.ToString());
33-
return result.ToString().Equals(str);
32+
string result_string = string.Join("", result.GetData<char>());
33+
Console.WriteLine(result_string);
34+
return result_string.Equals(str);
3435
});
3536
}
3637

0 commit comments

Comments
 (0)