Skip to content

Commit 14f8adb

Browse files
committed
add FuncGraph.
1 parent 242e051 commit 14f8adb

12 files changed

Lines changed: 279 additions & 22 deletions

File tree

src/TensorFlowNET.Console/TensorFlowNET.Console.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<Project Sdk="Microsoft.NET.Sdk">
1+
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*****************************************************************************
2+
Copyright 2020 Haiping Chen. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using Tensorflow.Graphs;
18+
using Tensorflow.Operations;
19+
20+
namespace Tensorflow
21+
{
22+
public partial class tensorflow
23+
{
24+
public AutoGraph autograph = new AutoGraph();
25+
}
26+
}

src/TensorFlowNET.Core/APIs/tf.control_flow.cs

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,22 @@ limitations under the License.
1515
******************************************************************************/
1616

1717
using System;
18+
using static Tensorflow.Binding;
1819

1920
namespace Tensorflow
2021
{
2122
public partial class tensorflow
2223
{
24+
public Tensor cond(Tensor pred,
25+
Tensor true_value,
26+
Tensor false_false)
27+
=> control_flow_ops.cond(pred, () => true_value, () => false_false);
28+
2329
public Tensor cond(Tensor pred,
2430
Func<ITensorOrOperation> true_fn = null,
2531
Func<ITensorOrOperation> false_fn = null,
26-
bool strict = false,
2732
string name = null)
28-
=> control_flow_ops.cond(pred, true_fn, false_fn, strict: strict, name: name);
33+
=> control_flow_ops.cond(pred, true_fn, false_fn, name: name);
2934

3035
/// <summary>
3136
/// Create an op that groups multiple operations.
@@ -37,22 +42,31 @@ public Tensor cond(Tensor pred,
3742
public Operation group<T>(T[] inputs, string name = null) where T : ITensorOrOperation
3843
=> control_flow_ops.group(inputs, name: name);
3944

40-
/*public Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
41-
TensorShape shape_invariants = null,
45+
public Tensor while_loop(Func<Tensor, Tensor> cond,
46+
Func<Tensor, Tensor> body,
47+
Tensor loop_vars,
48+
int parallel_iterations = 10)
49+
{
50+
Func<Tensor[], Tensor> cond1 = x
51+
=> cond(x[0]);
52+
53+
Func<Tensor[], Tensor[]> body1 = x
54+
=> new[] { body(x[0]) };
55+
56+
var results = control_flow_ops.while_loop(cond1,
57+
body1,
58+
new[] { loop_vars });
59+
return results[0];
60+
}
61+
62+
public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
63+
Func<Tensor[], Tensor[]> body,
64+
Tensor[] loop_vars,
4265
int parallel_iterations = 10,
43-
bool back_prop = true,
44-
bool swap_memory = false,
45-
string name = null,
46-
int? maximum_iterations = null,
47-
bool return_same_structure = false)
66+
string name = null)
4867
=> control_flow_ops.while_loop(cond, body, loop_vars,
49-
shape_invariants: shape_invariants,
5068
parallel_iterations: parallel_iterations,
51-
back_prop: back_prop,
52-
swap_memory: swap_memory,
53-
name: name,
54-
maximum_iterations: maximum_iterations,
55-
return_same_structure: return_same_structure);*/
69+
name: name);
5670

5771
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
5872
=> ops.control_dependencies(control_inputs);

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,37 @@ public partial class c_api
7878
[DllImport(TensorFlowLibName)]
7979
public static extern SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status);
8080

81+
/// <summary>
82+
/// Adds a function (created from TF_GraphToFunction or
83+
/// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with
84+
/// TFE_Execute by creating an op with the same name as the function.
85+
/// </summary>
86+
/// <param name="ctx"></param>
87+
/// <param name="function"></param>
88+
/// <param name="status"></param>
89+
[DllImport(TensorFlowLibName)]
90+
public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, IntPtr function, SafeStatusHandle status);
91+
92+
/// <summary>
93+
/// Removes a function from the context. Once removed, you can no longer
94+
/// TFE_Execute it or TFE_Execute any TFE_Op which has it as an attribute or any
95+
/// other function which calls it as an attribute.
96+
/// </summary>
97+
/// <param name="ctx"></param>
98+
/// <param name="name"></param>
99+
/// <param name="status"></param>
100+
[DllImport(TensorFlowLibName)]
101+
public static extern void TFE_ContextRemoveFunction(SafeContextHandle ctx, string name, SafeStatusHandle status);
102+
103+
/// <summary>
104+
/// Checks whether a function is registered under `name`.
105+
/// </summary>
106+
/// <param name="ctx"></param>
107+
/// <param name="name"></param>
108+
/// <returns></returns>
109+
[DllImport(TensorFlowLibName)]
110+
public static extern bool TFE_ContextHasFunction(SafeContextHandle ctx, string name);
111+
81112
[DllImport(TensorFlowLibName)]
82113
public static extern void TFE_ContextStartStep(SafeContextHandle ctx);
83114

src/TensorFlowNET.Core/Functions/c_api.function.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name,
3939
int num_opers, IntPtr[] opers,
4040
int ninputs, TF_Output[] inputs,
4141
int noutputs, TF_Output[] outputs,
42-
IntPtr output_names,
42+
string[] output_names,
4343
IntPtr opts,
4444
string description,
4545
SafeStatusHandle status);
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Linq.Expressions;
5+
using System.Text;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow.Graphs
9+
{
10+
public class AutoGraph
11+
{
12+
public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func)
13+
{
14+
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
15+
tf.compat.v1.disable_eager_execution();
16+
// IntPtr func_handle;
17+
using(var graph = new FuncGraph(func_name))
18+
{
19+
graph.as_default();
20+
var input1 = tf.placeholder(tf.int32);
21+
var input2 = tf.placeholder(tf.int32);
22+
var output = func(input1, input2);
23+
24+
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
25+
var func_handle = graph.ToGraph(opers,
26+
new Operation[] { input1, input2 },
27+
new Operation[] { output },
28+
null);
29+
30+
c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, tf.Status.Handle);
31+
}
32+
33+
tf.enable_eager_execution();
34+
35+
return (Tensor a, Tensor b) =>
36+
{
37+
var result = tf.Runner.TFE_Execute(tf.Context,
38+
tf.Context.DeviceName,
39+
func_name,
40+
new[] { a, b },
41+
null,
42+
1);
43+
return result[0];
44+
};
45+
}
46+
}
47+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*using MethodBoundaryAspect.Fody.Attributes;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using Tensorflow.Eager;
7+
using static Tensorflow.Binding;
8+
9+
namespace Tensorflow.Graphs
10+
{
11+
public sealed class AutoGraphAspect : OnMethodBoundaryAspect
12+
{
13+
FuncGraph graph;
14+
IntPtr func_handle;
15+
16+
public override void OnEntry(MethodExecutionArgs args)
17+
{
18+
tf.compat.v1.disable_eager_execution();
19+
// convert args to placeholder
20+
21+
for (var i = 0; i < args.Arguments.Length; i++)
22+
{
23+
if (args.Arguments[i] is EagerTensor tensor)
24+
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape);
25+
}
26+
27+
// make function as an Operation by autograph
28+
graph = new FuncGraph("autograph_add");
29+
graph.as_default();
30+
}
31+
32+
public override void OnExit(MethodExecutionArgs args)
33+
{
34+
var output = (Tensor)args.Method.Invoke(args.Instance, args.Arguments);
35+
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
36+
func_handle = graph.ToGraph(opers,
37+
new Operation[] { },
38+
new Operation[] { },
39+
null);
40+
41+
42+
c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, tf.Status.Handle);
43+
44+
var a1 = tf.constant(1);
45+
var b1 = tf.constant(2);
46+
47+
var result = tf.Runner.TFE_Execute(tf.Context,
48+
tf.Context.DeviceName,
49+
"autograph_add",
50+
new[] { a1, b1 },
51+
null,
52+
1);
53+
graph.Dispose();
54+
}
55+
}
56+
}*/
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using static Tensorflow.Binding;
6+
7+
namespace Tensorflow.Graphs
8+
{
9+
/// <summary>
10+
/// Graph representing a function body.
11+
/// </summary>
12+
public class FuncGraph : Graph
13+
{
14+
List<Operation> inputs;
15+
List<Operation> outputs;
16+
Graph outer_graph;
17+
string func_name;
18+
IntPtr func_handle;
19+
public string FuncName => c_api.StringPiece(c_api.TF_FunctionName(func_handle));
20+
21+
/// <summary>
22+
/// Construct a new FuncGraph.
23+
/// </summary>
24+
public FuncGraph(string name) : base()
25+
{
26+
outer_graph = ops.get_default_graph();
27+
func_name = name;
28+
}
29+
30+
public IntPtr ToGraph(Operation[] opers,
31+
Operation[] inputs, Operation[] outputs,
32+
string[] output_names)
33+
{
34+
using var status = new Status();
35+
func_handle = c_api.TF_GraphToFunction(_handle,
36+
func_name,
37+
false,
38+
opers.Length,
39+
opers.Select(x => (IntPtr)x).ToArray(),
40+
inputs.Length,
41+
inputs.Select(x => new TF_Output(x, 0)).ToArray(),
42+
outputs.Length,
43+
outputs.Select(x => new TF_Output(x, 0)).ToArray(),
44+
output_names == null || output_names.Length == 0 ? null : output_names,
45+
IntPtr.Zero,
46+
null,
47+
status.Handle);
48+
49+
c_api.TF_GraphCopyFunction(outer_graph, func_handle, IntPtr.Zero, status.Handle);
50+
51+
return func_handle;
52+
}
53+
}
54+
}

src/TensorFlowNET.Core/Operations/control_flow_ops.cs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
using util = Tensorflow.control_flow_util;
2323
using static Tensorflow.Binding;
2424
using Tensorflow.Util;
25+
using System.Data;
2526

2627
namespace Tensorflow
2728
{
@@ -420,14 +421,13 @@ public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name
420421
public static Tensor cond(Tensor pred,
421422
Func<ITensorOrOperation> true_fn = null,
422423
Func<ITensorOrOperation> false_fn = null,
423-
bool strict = false,
424424
string name = null)
425425
{
426426
return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
427427
{
428428
if (tf.Context.executing_eagerly())
429429
{
430-
if (pred.ToArray<bool>()[0])
430+
if ((bool)pred)
431431
return true_fn() as Tensor;
432432
else
433433
return false_fn() as Tensor;
@@ -676,6 +676,29 @@ public static Tensor ZerosLikeOutsideLoop(Operation op, int index)
676676
}
677677
}
678678

679+
public static Tensor[] while_loop(Func<Tensor[], Tensor> cond,
680+
Func<Tensor[], Tensor[]> body,
681+
Tensor[] loop_vars,
682+
int parallel_iterations = 10,
683+
string name = null)
684+
{
685+
var executing_eagerly = tf.Context.executing_eagerly();
686+
if (!executing_eagerly)
687+
{
688+
throw new NotImplementedException("");
689+
}
690+
691+
return tf_with(ops.name_scope("name", "while"), delegate
692+
{
693+
while ((bool)cond(loop_vars))
694+
{
695+
loop_vars = body(loop_vars);
696+
}
697+
698+
return loop_vars;
699+
});
700+
}
701+
679702
/// <summary>
680703
/// Repeat `body` while the condition `cond` is true.
681704
/// </summary>

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ https://tensorflownet.readthedocs.io</Description>
2828
<FileVersion>0.20.1.0</FileVersion>
2929
<PackageLicenseFile>LICENSE</PackageLicenseFile>
3030
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
31-
<SignAssembly>true</SignAssembly>
31+
<SignAssembly>false</SignAssembly>
3232
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
3333
<Platforms>AnyCPU;x64</Platforms>
3434
</PropertyGroup>
@@ -83,4 +83,10 @@ https://tensorflownet.readthedocs.io</Description>
8383
<ItemGroup>
8484
<Folder Include="Keras\Initializers\" />
8585
</ItemGroup>
86+
87+
<ItemGroup>
88+
<None Update="FodyWeavers.xml">
89+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
90+
</None>
91+
</ItemGroup>
8692
</Project>

0 commit comments

Comments
 (0)