Skip to content

Commit fe3440f

Browse files
committed
Merge remote-tracking branch 'upstream/master'
# Conflicts: # src/TensorFlowNET.Core/Operations/Operation.cs # test/TensorFlowNET.UnitTest/PythonTest.cs # test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
2 parents 9dba680 + 6bd08d6 commit fe3440f

17 files changed

Lines changed: 1672 additions & 396 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public static Tensor acos(Tensor x, string name = null)
2727
public static Tensor asin(Tensor x, string name = null)
2828
=> gen_math_ops.asin(x, name);
2929

30-
public static Tensor add(Tensor a, Tensor b)
30+
public static Tensor add<Tx, Ty>(Tx a, Ty b)
3131
=> gen_math_ops.add(a, b);
3232

3333
/// <summary>
@@ -251,7 +251,7 @@ public static Tensor maximum<T1, T2>(T1 x, T2 y, string name = null)
251251
public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null)
252252
=> gen_math_ops.minimum(x, y, name: name);
253253

254-
public static Tensor multiply(Tensor x, Tensor y)
254+
public static Tensor multiply<Tx, Ty>(Tx x, Ty y)
255255
=> gen_math_ops.mul(x, y);
256256

257257
public static Tensor negative(Tensor x, string name = null)

src/TensorFlowNET.Core/Framework/meta_graph.py.cs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.IO;
55
using System.Linq;
66
using System.Text;
7+
using Tensorflow.Operations;
78
using static Tensorflow.CollectionDef;
89
using static Tensorflow.MetaGraphDef.Types;
910

@@ -95,15 +96,29 @@ public static (Dictionary<string, RefVariable>, ITensorOrOperation[]) import_sco
9596
}
9697
else
9798
{
98-
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
99+
foreach(var value in col.Value.BytesList.Value)
100+
{
101+
switch (col.Key)
102+
{
103+
case "cond_context":
104+
var proto = CondContextDef.Parser.ParseFrom(value);
105+
var condContext = new CondContext().from_proto(proto, import_scope);
106+
graph.add_to_collection(col.Key, condContext);
107+
break;
108+
default:
109+
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
110+
}
111+
}
99112
}
100113

101114
break;
115+
default:
116+
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
102117
}
103118
}
104119

105-
var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
106-
scope: scope_to_prepend_to_names) as List<RefVariable>;
120+
var variables = graph.get_collection<RefVariable>(ops.GraphKeys.GLOBAL_VARIABLES,
121+
scope: scope_to_prepend_to_names);
107122
var var_list = new Dictionary<string, RefVariable>();
108123
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v);
109124

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,11 @@ public object get_collection(string name, string scope = null)
412412
return _collections.ContainsKey(name) ? _collections[name] : null;
413413
}
414414

415+
public List<T> get_collection<T>(string name, string scope = null)
416+
{
417+
return _collections.ContainsKey(name) ? _collections[name] as List<T> : new List<T>();
418+
}
419+
415420
public object get_collection_ref(string name)
416421
{
417422
if (!_collections.ContainsKey(name))

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Tensorflow.Operations
88
/// <summary>
99
/// The context for the conditional construct.
1010
/// </summary>
11-
public class CondContext : ControlFlowContext
11+
public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext>
1212
{
1313

1414

@@ -35,16 +35,20 @@ public class CondContext : ControlFlowContext
3535
/// <param name="name">Name of the `CondContext` python object.</param>
3636
/// <param name="context_def"></param>
3737
/// <param name="import_scope"></param>
38-
public CondContext(Tensor pred,
39-
Tensor pivot,
40-
int branch,
38+
public CondContext(Tensor pred = null,
39+
Tensor pivot = null,
40+
int? branch = null,
4141
string name = "cond_text",
42-
object context_def = null,
42+
CondContextDef context_def = null,
4343
string import_scope = null)
4444
{
45+
if (pred == null && context_def == null) return;
46+
4547
_name = ops.get_default_graph().unique_name(name);
46-
if (context_def != null)
47-
throw new NotImplementedException("CondContext context_def is not null");
48+
if (context_def != null)
49+
{
50+
_init_from_proto(context_def, import_scope: import_scope);
51+
}
4852
else
4953
{
5054
// Initializes the default fields.
@@ -61,6 +65,18 @@ public CondContext(Tensor pred,
6165
}
6266
}
6367

68+
private void _init_from_proto(CondContextDef context_def, string import_scope = null)
69+
{
70+
var g = ops.get_default_graph();
71+
_name = ops.prepend_name_scope(context_def.ContextName, import_scope);
72+
var p1 = ops.prepend_name_scope(context_def.PredName, import_scope);
73+
_pred = g.as_graph_element(p1) as Tensor;
74+
var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope);
75+
_pivot = g.as_graph_element(p2) as Tensor;
76+
_branch = context_def.Branch;
77+
__init__(values_def: context_def.ValuesDef, import_scope: import_scope);
78+
}
79+
6480
/// <summary>
6581
/// Add `val` to the current context and its outer context recursively.
6682
/// </summary>
@@ -230,6 +246,22 @@ private Tensor _ProcessOutputTensor(Tensor val)
230246
public override void AddInnerOp(Operation resultOp)
231247
{
232248
throw new NotImplementedException();
233-
}
249+
}
250+
251+
public CondContextDef to_proto(string export_scope)
252+
{
253+
throw new NotImplementedException();
254+
}
255+
256+
public CondContext from_proto(CondContextDef proto, string import_scope)
257+
{
258+
var ret = new CondContext(context_def: proto, import_scope: import_scope);
259+
260+
ret.Enter();
261+
foreach (var nested_def in proto.NestedContexts)
262+
throw new NotImplementedException("");
263+
ret.Exit();
264+
return ret;
265+
}
234266
}
235267
}

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ public abstract class ControlFlowContext : Python, IPython, IControlFlowContext
3232
protected Stack<IControlFlowContext> _context_stack;
3333
protected IControlFlowContext _outer_context;
3434

35+
protected Dictionary<string, ITensorOrOperation> _external_values;
36+
3537
public ControlFlowContext()
3638
{
3739
_context_stack = new Stack<IControlFlowContext>();
@@ -40,15 +42,43 @@ public ControlFlowContext()
4042
public string name { get => _name; }
4143
protected string _name;
4244

43-
public void __init__()
45+
public void __init__(ValuesDef values_def = null, string import_scope = null)
4446
{
45-
47+
_outer_context = ops.get_default_graph()._get_control_flow_context();
48+
if (values_def != null)
49+
_init_values_from_proto(values_def, import_scope: import_scope);
4650
}
4751

4852
public void __enter__()
4953
{
5054
}
5155

56+
/// <summary>
57+
/// Initializes values and external_values from `ValuesDef` protocol buffer.
58+
/// </summary>
59+
/// <param name="values_def"></param>
60+
/// <param name="import_scope"></param>
61+
protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null)
62+
{
63+
_external_values = new Dictionary<string, ITensorOrOperation>();
64+
foreach (var value in values_def.Values)
65+
_values.Add(value);
66+
var g = ops.get_default_graph();
67+
foreach(var value in values_def.ExternalValues)
68+
{
69+
var k = ops.prepend_name_scope(value.Key, import_scope);
70+
var v = value.Value;
71+
_external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope));
72+
}
73+
74+
var op_names = _values.Where(x => !_external_values.ContainsKey(x))
75+
.Select(x => x.Split(':')[0])
76+
.ToArray();
77+
78+
foreach (var op in op_names)
79+
(g.as_graph_element(op) as Operation)._set_control_flow_context(this);
80+
}
81+
5282
public void __exit__()
5383
{
5484
}

src/TensorFlowNET.Core/Operations/Operation.Output.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ public unsafe Operation[] GetControlOutputs()
4242
if (NumControlOutputs > 0)
4343
{
4444
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs);
45-
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs);
46-
for (int i = 0; i < NumControlInputs; i++)
45+
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlOutputs);
46+
for (int i = 0; i < NumControlOutputs; i++)
4747
{
4848
var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i;
4949
control_outputs[i] = new Operation(*(IntPtr*)handle);

0 commit comments

Comments
 (0)