Skip to content

Commit 16ea351

Browse files
committed
fix tf.cond.
1 parent 7ec1422 commit 16ea351

24 files changed

Lines changed: 234 additions & 77 deletions

src/TensorFlowNET.Core/APIs/tf.io.cs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,32 @@ namespace Tensorflow
2121
{
2222
public partial class tensorflow
2323
{
24+
public IoApi io { get; } = new IoApi();
25+
26+
public class IoApi
27+
{
28+
io_ops ops;
29+
public IoApi()
30+
{
31+
ops = new io_ops();
32+
}
33+
34+
public Tensor read_file(string filename, string name = null)
35+
=> ops.read_file(filename, name);
36+
37+
public Tensor read_file(Tensor filename, string name = null)
38+
=> ops.read_file(filename, name);
39+
40+
public Operation save_v2(Tensor prefix, string[] tensor_names,
41+
string[] shape_and_slices, Tensor[] tensors, string name = null)
42+
=> ops.save_v2(prefix, tensor_names, shape_and_slices, tensors, name: name);
43+
44+
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names,
45+
string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
46+
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name);
47+
}
48+
2449
public GFile gfile = new GFile();
25-
public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name);
26-
public Tensor read_file(Tensor filename, string name = null) => gen_io_ops.read_file(filename, name);
2750

2851
public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
2952
Dictionary<string, Tensor> input_map = null,

src/TensorFlowNET.Core/APIs/tf.strings.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,28 @@ namespace Tensorflow
2121
{
2222
public partial class tensorflow
2323
{
24-
public strings_internal strings = new strings_internal();
25-
public class strings_internal
24+
public StringsApi strings { get; } = new StringsApi();
25+
26+
public class StringsApi
2627
{
28+
string_ops ops = new string_ops();
29+
30+
/// <summary>
31+
/// Return substrings from `Tensor` of strings.
32+
/// </summary>
33+
/// <param name="input"></param>
34+
/// <param name="pos"></param>
35+
/// <param name="len"></param>
36+
/// <param name="name"></param>
37+
/// <param name="uint"></param>
38+
/// <returns></returns>
2739
public Tensor substr(Tensor input, int pos, int len,
2840
string name = null, string @uint = "BYTE")
29-
=> string_ops.substr(input, pos, len, name: name, @uint: @uint);
41+
=> ops.substr(input, pos, len, @uint: @uint, name: name);
42+
43+
public Tensor substr(string input, int pos, int len,
44+
string name = null, string @uint = "BYTE")
45+
=> ops.substr(input, pos, len, @uint: @uint, name: name);
3046
}
3147
}
3248
}

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public Tensor[] TFE_ExecuteCancelable(Context ctx,
4747
status.Check(true);
4848
}
4949
}
50-
if (status.ok())
50+
if (status.ok() && attrs != null)
5151
SetOpAttrs(op, attrs);
5252

5353
var outputs = new IntPtr[num_outputs];

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,6 @@ bool AddInputToOp(object inputs,
204204
input_handle = input.EagerTensorHandle;
205205
flattened_inputs.Add(input);
206206
break;
207-
case EagerTensor[] input_list:
208-
input_handle = input_list[0].EagerTensorHandle;
209-
break;
210207
default:
211208
var tensor = tf.convert_to_tensor(inputs);
212209
input_handle = (tensor as EagerTensor).EagerTensorHandle;

src/TensorFlowNET.Core/Operations/control_flow_ops.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,16 @@ public static Tensor cond(Tensor pred,
376376
{
377377
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
378378
{
379+
if (tf.context.executing_eagerly())
380+
{
381+
if (pred.ToArray<bool>()[0])
382+
return true_fn() as Tensor;
383+
else
384+
return false_fn() as Tensor;
385+
386+
return null;
387+
}
388+
379389
// Add the Switch to the graph.
380390
var switch_result= @switch(pred, pred);
381391
var (p_2, p_1 )= (switch_result[0], switch_result[1]);
@@ -450,6 +460,16 @@ public static Tensor[] cond<T>(Tensor pred,
450460
{
451461
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
452462
{
463+
if (tf.context.executing_eagerly())
464+
{
465+
if (pred.ToArray<bool>()[0])
466+
return true_fn() as Tensor[];
467+
else
468+
return false_fn() as Tensor[];
469+
470+
return null;
471+
}
472+
453473
// Add the Switch to the graph.
454474
var switch_result = @switch(pred, pred);
455475
var p_2 = switch_result[0];

src/TensorFlowNET.Core/Operations/gen_string_ops.cs

Lines changed: 0 additions & 40 deletions
This file was deleted.

src/TensorFlowNET.Core/Operations/image_ops_impl.cs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using System;
1818
using System.Collections.Generic;
19+
using System.Linq;
1920
using System.Text;
2021
using Tensorflow.Operations;
2122
using static Tensorflow.Binding;
@@ -63,7 +64,7 @@ public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType
6364
Func<ITensorOrOperation> _bmp = () =>
6465
{
6566
int bmp_channels = channels;
66-
var signature = string_ops.substr(contents, 0, 2);
67+
var signature = tf.strings.substr(contents, 0, 2);
6768
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
6869
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
6970
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
@@ -98,7 +99,7 @@ public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType
9899

99100
return tf_with(ops.name_scope(name, "decode_image"), scope =>
100101
{
101-
substr = string_ops.substr(contents, 0, 3);
102+
substr = tf.strings.substr(contents, 0, 3);
102103
return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
103104
});
104105
}
@@ -128,16 +129,19 @@ public static Tensor is_jpeg(Tensor contents, string name = null)
128129
{
129130
return tf_with(ops.name_scope(name, "is_jpeg"), scope =>
130131
{
131-
var substr = string_ops.substr(contents, 0, 3);
132-
return math_ops.equal(substr, "\xff\xd8\xff", name: name);
132+
var substr = tf.strings.substr(contents, 0, 3);
133+
var jpg = Encoding.UTF8.GetString(new byte[] { 0xff, 0xd8, 0xff });
134+
var jpg_tensor = tf.constant(jpg);
135+
var result = math_ops.equal(substr, jpg_tensor, name: name);
136+
return result;
133137
});
134138
}
135139

136140
public static Tensor _is_png(Tensor contents, string name = null)
137141
{
138142
return tf_with(ops.name_scope(name, "is_png"), scope =>
139143
{
140-
var substr = string_ops.substr(contents, 0, 3);
144+
var substr = tf.strings.substr(contents, 0, 3);
141145
return math_ops.equal(substr, @"\211PN", name: name);
142146
});
143147
}

src/TensorFlowNET.Core/Operations/gen_io_ops.cs renamed to src/TensorFlowNET.Core/Operations/io_ops.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,45 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.Eager;
1718
using static Tensorflow.Binding;
1819

1920
namespace Tensorflow
2021
{
21-
public class gen_io_ops
22+
public class io_ops
2223
{
23-
public static Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null)
24+
public Operation save_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name = null)
2425
{
2526
var _op = tf._op_def_lib._apply_op_helper("SaveV2", name: name, args: new { prefix, tensor_names, shape_and_slices, tensors });
2627

2728
return _op;
2829
}
2930

30-
public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
31+
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
3132
{
3233
var _op = tf._op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes });
3334

3435
return _op.outputs;
3536
}
3637

37-
public static Tensor read_file<T>(T filename, string name = null)
38+
public Tensor read_file<T>(T filename, string name = null)
3839
{
40+
if (tf.context.executing_eagerly())
41+
{
42+
return read_file_eager_fallback(filename, name: name, tf.context);
43+
}
44+
3945
var _op = tf._op_def_lib._apply_op_helper("ReadFile", name: name, args: new { filename });
4046

4147
return _op.outputs[0];
4248
}
49+
50+
private Tensor read_file_eager_fallback<T>(T filename, string name = null, Context ctx = null)
51+
{
52+
var filename_tensor = ops.convert_to_tensor(filename, TF_DataType.TF_STRING);
53+
var _inputs_flat = new[] { filename_tensor };
54+
55+
return tf._execute.execute(ctx, "ReadFile", 1, _inputs_flat, null, name: name)[0];
56+
}
4357
}
4458
}

src/TensorFlowNET.Core/Operations/string_ops.cs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
using System;
1818
using System.Collections.Generic;
1919
using System.Text;
20+
using static Tensorflow.Binding;
2021

2122
namespace Tensorflow
2223
{
@@ -31,8 +32,30 @@ public class string_ops
3132
/// <param name="name"></param>
3233
/// <param name="uint"></param>
3334
/// <returns></returns>
34-
public static Tensor substr(Tensor input, int pos, int len,
35-
string name = null, string @uint = "BYTE")
36-
=> gen_string_ops.substr(input, pos, len, name: name, @uint: @uint);
35+
public Tensor substr<T>(T input, int pos, int len,
36+
string @uint = "BYTE", string name = null)
37+
{
38+
if (tf.context.executing_eagerly())
39+
{
40+
var input_tensor = tf.constant(input);
41+
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
42+
"Substr", name,
43+
null,
44+
input, pos, len,
45+
"unit", @uint);
46+
47+
return results[0];
48+
}
49+
50+
var _op = tf._op_def_lib._apply_op_helper("Substr", name: name, args: new
51+
{
52+
input,
53+
pos,
54+
len,
55+
unit = @uint
56+
});
57+
58+
return _op.output;
59+
}
3760
}
3861
}

src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ public T ToScalar<T>()
6868
throw new ArgumentException($"{nameof(Tensor)} can only be scalar.");
6969

7070
IntPtr stringStartAddress = IntPtr.Zero;
71-
UIntPtr dstLen = UIntPtr.Zero;
71+
ulong dstLen = 0;
7272

73-
c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status.Handle);
73+
c_api.TF_StringDecode((byte*) this.buffer + 8, this.bytesize, (byte**) &stringStartAddress, ref dstLen, tf.status.Handle);
7474
tf.status.Check(true);
7575

7676
var dstLenInt = checked((int) dstLen);

0 commit comments

Comments
 (0)