Skip to content

Commit 1647591

Browse files
committed
merge with master.
2 parents 990e774 + 174241d commit 1647591

21 files changed

Lines changed: 639 additions & 200 deletions

File tree

README.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
# TensorFlow.NET
2-
![logo](docs/assets/tf.net.logo.svg)
1+
![logo](docs/assets/tf.net.logo.png)
32

4-
TensorFlow.NET (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
3+
**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
54

65
[![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community)
76
[![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/wx4td43v2d3f2xj6?svg=true)](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net)
@@ -10,7 +9,8 @@ TensorFlow.NET (TF.NET) provides a .NET Standard binding for [TensorFlow](https:
109
[![Documentation Status](https://readthedocs.org/projects/tensorflownet/badge/?version=latest)](https://tensorflownet.readthedocs.io/en/latest/?badge=latest)
1110
[![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US)
1211

13-
TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).
12+
TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). <a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp_badge.png" width="200" height="200" align="right" /></a>
13+
1414

1515
![tensors_flowing](docs/assets/tensors_flowing.gif)
1616

@@ -200,5 +200,6 @@ Scan QR code to join Tencent TIM group:
200200

201201
![SciSharp STACK](docs/TIM.jpg)
202202

203-
![SciSharp](https://avatars3.githubusercontent.com/u/44989469) TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/)
204-
203+
TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/)
204+
<br>
205+
<a href="http://scisharpstack.org"><img src="https://github.com/SciSharp/SciSharp/blob/master/art/scisharp-stack.png" width="391" height="100" /></a>

docs/assets/tf.net.logo.png

18 KB
Loading

docs/assets/tf.net.logo.svg

Lines changed: 138 additions & 155 deletions
Loading

src/TensorFlowNET.Core/APIs/tf.control.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,29 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System;
18+
1719
namespace Tensorflow
1820
{
1921
public static partial class tf
2022
{
23+
public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor> body, Tensor[] loop_vars,
24+
TensorShape shape_invariants = null,
25+
int parallel_iterations = 10,
26+
bool back_prop = true,
27+
bool swap_memory = false,
28+
string name = null,
29+
int? maximum_iterations = null,
30+
bool return_same_structure = false)
31+
=> control_flow_ops.while_loop(cond, body, loop_vars,
32+
shape_invariants: shape_invariants,
33+
parallel_iterations: parallel_iterations,
34+
back_prop: back_prop,
35+
swap_memory: swap_memory,
36+
name: name,
37+
maximum_iterations: maximum_iterations,
38+
return_same_structure: return_same_structure);
39+
2140
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
2241
=> ops.control_dependencies(control_inputs);
2342
}

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ public static Tensor acos(Tensor x, string name = null)
3939
public static Tensor asin(Tensor x, string name = null)
4040
=> gen_math_ops.asin(x, name);
4141

42-
public static Tensor add<Tx, Ty>(Tx a, Ty b)
43-
=> gen_math_ops.add(a, b);
42+
public static Tensor add<Tx, Ty>(Tx a, Ty b, string name = null)
43+
=> gen_math_ops.add(a, b, name: name);
4444

4545
/// <summary>
4646
/// Computes atan of x element-wise.
@@ -198,6 +198,9 @@ public static Tensor logical_not(Tensor x, string name = null)
198198
public static Tensor logical_or(Tensor x, Tensor y, string name = null)
199199
=> gen_math_ops.logical_or(x, y, name);
200200

201+
public static Tensor logical_xor(Tensor x, Tensor y, string name = "LogicalXor")
202+
=> gen_math_ops.logical_xor(x, y, name);
203+
201204
/// <summary>
202205
/// Clips tensor values to a specified min and max.
203206
/// </summary>

src/TensorFlowNET.Core/Framework/Models/ScopedTFGraph.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,5 @@ public ScopedTFGraph() : base()
66
{
77

88
}
9-
10-
~ScopedTFGraph()
11-
{
12-
base.Dispose();
13-
}
149
}
1510
}

src/TensorFlowNET.Core/Graphs/Graph.Control.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public partial class Graph
3333
/// </summary>
3434
/// <param name="input_ops">The data input ops for an op to be created.</param>
3535
/// <returns>A list of control inputs for the op to be created.</returns>
36-
private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops)
36+
public ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops)
3737
{
3838
var ret = new List<ITensorOrOperation>();
3939

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,6 @@ protected override void DisposeManagedState()
445445

446446
protected override void DisposeUnManagedState(IntPtr handle)
447447
{
448-
Console.WriteLine($"Destroy graph {handle}");
449448
c_api.TF_DeleteGraph(handle);
450449
}
451450

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ public Tensor pivot
5353
protected Stack<ControlFlowContext> _context_stack;
5454
protected ControlFlowContext _outer_context;
5555

56+
/// <summary>
57+
/// The keys are the names of tensors referenced by but external to this
58+
/// context. Each value is the Tensor that should be used by this context to
59+
/// access the key value (e.g. a switch output guarding a cond input value).
60+
/// </summary>
5661
protected Dictionary<string, ITensorOrOperation> _external_values;
5762

5863
public ControlFlowContext()
@@ -68,6 +73,12 @@ public void __init__(ValuesDef values_def = null, string import_scope = null)
6873
_outer_context = ops.get_default_graph()._get_control_flow_context();
6974
if (values_def != null)
7075
_init_values_from_proto(values_def, import_scope: import_scope);
76+
else
77+
{
78+
_values = new HashSet<string>();
79+
_external_values = new Dictionary<string, ITensorOrOperation>();
80+
}
81+
7182
}
7283

7384
public void __enter__()
@@ -114,6 +125,27 @@ public virtual void Enter()
114125
graph._set_control_flow_context(this);
115126
}
116127

128+
protected virtual Tensor _Enter(Tensor data, string frame_name,
129+
bool is_constant = false,
130+
int parallel_iterations = 10,
131+
bool use_ref = true,
132+
bool use_input_shape = true,
133+
string name = null)
134+
{
135+
Tensor result;
136+
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref: true);
137+
if (data.dtype.is_ref_dtype() && use_ref)
138+
throw new NotImplementedException("_Enter");
139+
else
140+
result = gen_control_flow_ops.enter(
141+
data, frame_name, is_constant, parallel_iterations, name: name);
142+
143+
if (use_input_shape)
144+
result.SetShape(data.TensorShape);
145+
146+
return result;
147+
}
148+
117149
/// <summary>
118150
/// Exit this control flow context.
119151
/// </summary>
@@ -184,6 +216,10 @@ public static bool IsContainingContext(ControlFlowContext ctxt, ControlFlowConte
184216
return true;
185217
}
186218

219+
protected virtual bool _IsInOuterContext(Operation op)
220+
{
221+
throw new NotImplementedException("_IsInOuterContext");
222+
}
187223

188224
protected virtual void _RemoveExternalControlEdges(Operation op)
189225
{

src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

Lines changed: 165 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@ limitations under the License.
1515
******************************************************************************/
1616

1717
using System;
18+
using System.Collections.Generic;
19+
using System.Linq;
1820
using Tensorflow.Operations.ControlFlows;
21+
using Tensorflow.Util;
1922
using static Tensorflow.Python;
23+
using static Tensorflow.control_flow_ops;
2024

2125
namespace Tensorflow.Operations
2226
{
@@ -32,10 +36,14 @@ public class WhileContext : ControlFlowContext
3236
bool _swap_memory;
3337
Tensor _pivot_for_pred;
3438
Tensor _pivot_for_body;
35-
Tensor[] _loop_exits;
36-
Tensor[] _loop_enters;
39+
List<Tensor> _loop_exits;
40+
List<Tensor> _loop_enters;
41+
Graph _graph;
42+
public override GradLoopState grad_state => _grad_state;
43+
public override bool back_prop => _back_prop;
3744

38-
public WhileContext(int parallel_iterations = 10,
45+
public WhileContext(int? maximum_iterations = null,
46+
int parallel_iterations = 10,
3947
bool back_prop = true,
4048
bool swap_memory = false,
4149
string name = "while_context",
@@ -49,12 +57,27 @@ public WhileContext(int parallel_iterations = 10,
4957
}
5058
else
5159
{
52-
60+
__init__();
61+
_init_from_args(maximum_iterations, parallel_iterations, back_prop, swap_memory, name);
5362
}
5463

5564
_grad_state = grad_state;
5665
}
5766

67+
private void _init_from_args(int? maximum_iterations,
68+
int parallel_iterations,
69+
bool back_prop,
70+
bool swap_memory,
71+
string name)
72+
{
73+
_name = ops.get_default_graph().unique_name(name);
74+
_back_prop = back_prop;
75+
_swap_memory = swap_memory;
76+
_loop_exits = new List<Tensor>();
77+
_loop_enters = new List<Tensor>();
78+
_graph = ops.get_default_graph();
79+
}
80+
5881
private void _init_from_proto(WhileContextDef context_def, string import_scope = null)
5982
{
6083
var g = ops.get_default_graph();
@@ -70,26 +93,156 @@ private void _init_from_proto(WhileContextDef context_def, string import_scope =
7093
// The boolean tensor for loop termination condition.
7194
_pivot = g.as_graph_element(ops.prepend_name_scope(context_def.PivotName, import_scope)) as Tensor;
7295
// The list of exit tensors for loop variables.
73-
_loop_exits = new Tensor[context_def.LoopExitNames.Count];
96+
_loop_exits = new List<Tensor>();
7497
foreach (var (i, exit_name) in enumerate(context_def.LoopExitNames))
75-
_loop_exits[i] = g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor;
98+
_loop_exits.Add(g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor);
7699
// The list of enter tensors for loop variables.
77-
_loop_enters = new Tensor[context_def.LoopEnterNames.Count];
100+
_loop_enters = new List<Tensor>();
78101
foreach (var (i, enter_name) in enumerate(context_def.LoopEnterNames))
79-
_loop_enters[i] = g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor;
102+
_loop_enters.Add(g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor);
80103

81104
__init__(values_def: context_def.ValuesDef, import_scope: import_scope);
82105
}
83106

84-
public override WhileContext GetWhileContext()
107+
/// <summary>
108+
/// Add the loop termination condition and body to the graph.
109+
/// </summary>
110+
public Tensor[] BuildLoop(Func<Tensor, Tensor> pred,
111+
Func<Tensor, Tensor> body,
112+
Tensor[] loop_vars,
113+
TensorShape shape_invariants,
114+
bool return_same_structure)
85115
{
86-
return this;
116+
// Keep original_loop_vars to identify which are TensorArrays
117+
var original_loop_vars = loop_vars;
118+
// Convert TensorArrays to their flow variables
119+
Enter();
120+
var(original_body_result, exit_vars) = _BuildLoop(
121+
pred, body, original_loop_vars, loop_vars, shape_invariants);
122+
Exit();
123+
124+
var flat_result = original_body_result;
125+
126+
var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars);
127+
var packed_exit_vars = nest.pack_sequence_as(
128+
structure: original_body_result,
129+
flat_sequence: exit_vars_with_tensor_arrays);
130+
131+
return packed_exit_vars as Tensor[];
87132
}
88133

134+
private (Tensor[], Tensor[]) _BuildLoop(Func<Tensor, Tensor> pred,
135+
Func<Tensor, Tensor> body,
136+
Tensor[] original_loop_vars,
137+
Tensor[] loop_vars,
138+
TensorShape shape_invariants)
139+
{
140+
var flat_loop_vars = original_loop_vars;
89141

90-
public override GradLoopState grad_state => _grad_state;
142+
// Let the context know the loop variables so the loop variables
143+
// would be added in the outer contexts properly.
144+
_InitializeValues(loop_vars);
145+
var real_vars = loop_vars;
146+
Tensor[] enter_vars = null;
147+
tf_with(ops.control_dependencies(null), delegate
148+
{
149+
enter_vars = real_vars.Select(x => _Enter(x,
150+
_name,
151+
is_constant: false,
152+
parallel_iterations: _parallel_iterations,
153+
use_input_shape: shape_invariants == null))
154+
.ToArray();
91155

92-
public override bool back_prop => _back_prop;
156+
foreach(var x in enter_vars)
157+
{
158+
x.graph.prevent_feeding(x);
159+
if (_outer_context != null)
160+
_outer_context.AddInnerOp(x.op);
161+
}
162+
});
163+
164+
// Finds the closest enclosing non-None control pivot.
165+
var outer_context = _outer_context;
166+
while (outer_context != null)
167+
{
168+
169+
}
170+
171+
_SetShapeInvariants(real_vars, enter_vars, shape_invariants);
172+
173+
// Fix the control inputs and control flow context of these enter ops.
174+
_FixControlInputsAndContext(enter_vars);
175+
_InitializeValues(enter_vars);
176+
_loop_enters = enter_vars.ToList();
177+
178+
var merge_vars = enter_vars
179+
.Select(x => merge(new[] { x, x }))
180+
.ToArray();
181+
182+
_pivot_for_pred = merge_vars[0];
183+
184+
// Build the graph for pred.
185+
var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
186+
// var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
187+
var c = ops.convert_to_tensor(pred(merge_vars_with_tensor_arrays[0]));
188+
_pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
189+
var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
190+
.ToArray();
191+
192+
// Build the graph for body.
193+
var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
194+
// Convert TensorArray flow variables inside the context back into
195+
// their associated TensorArrays for calling the body.
196+
var packed_vars_for_body = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
197+
var body_result = body(packed_vars_for_body[0]);
198+
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
199+
200+
// Store body_result to keep track of TensorArrays returned by body
201+
var original_body_result = new[] { body_result };
202+
// Convert TensorArrays returned by body into their flow variables
203+
var result = new[] { body_result };
204+
205+
var next_vars = new List<Tensor>();
206+
foreach (var (m, v) in zip(merge_vars, result))
207+
next_vars.Add(_AddNextAndBackEdge(m, v));
208+
209+
// Add the exit ops.
210+
var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
211+
_loop_exits = exit_vars;
212+
213+
// Exit the loop.
214+
// ExitResult(exit_vars);
215+
return (original_body_result, exit_vars.ToArray());
216+
}
217+
218+
private void _FixControlInputsAndContext(Tensor[] enters)
219+
{
220+
var graph = ops.get_default_graph();
221+
foreach(var e in enters)
222+
{
223+
var inp_op = e.op.inputs[0].op;
224+
var control_inputs = graph._control_dependencies_for_inputs(new[] { inp_op });
225+
// op for op in control_inputs if self._IsInOuterContext(op)
226+
var outer_control_inputs = control_inputs.Where(x => _IsInOuterContext(x.op))
227+
.Select(x => x.op)
228+
.ToArray();
229+
e.op._set_control_flow_context(this);
230+
e.op._add_control_inputs(outer_control_inputs);
231+
graph._record_op_seen_by_control_dependencies(e.op);
232+
}
233+
}
234+
235+
private void _InitializeValues(Tensor[] values)
236+
{
237+
_values = new HashSet<string>();
238+
foreach(var x in values)
239+
_values.Add(x.name);
240+
}
241+
242+
public override WhileContext GetWhileContext()
243+
{
244+
return this;
245+
}
93246

94247
public WhileContext from_proto(WhileContextDef proto, string import_scope)
95248
{

0 commit comments

Comments
 (0)