Skip to content

Commit 3f9f2a3

Browse files
committed
changed the class Python to be static
1 parent a38bd5d commit 3f9f2a3

64 files changed

Lines changed: 339 additions & 277 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/TensorFlowNET.Core/Clustering/KMeans.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using static Tensorflow.Python;
45

56
namespace Tensorflow.Clustering
67
{
78
/// <summary>
89
/// Creates the graph for k-means clustering.
910
/// </summary>
10-
public class KMeans : Python
11+
public class KMeans
1112
{
1213
public const string CLUSTERS_VAR_NAME = "clusters";
1314

src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Text;
5+
using static Tensorflow.Python;
56

67
namespace Tensorflow.Clustering
78
{
89
/// <summary>
910
/// Internal class to create the op to initialize the clusters.
1011
/// </summary>
11-
public class _InitializeClustersOpFactory : Python
12+
public class _InitializeClustersOpFactory
1213
{
1314
Tensor[] _inputs;
1415
Tensor _num_clusters;

src/TensorFlowNET.Core/Framework/importer.py.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
using System.Linq;
55
using System.Text;
66
using static Tensorflow.OpDef.Types;
7+
using static Tensorflow.Python;
78

89
namespace Tensorflow
910
{
10-
public class importer : Python
11+
public class importer
1112
{
1213
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
1314
Dictionary<string, Tensor> input_map = null,

src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
using System.Linq;
55
using System.Text;
66
using System.Threading;
7+
using static Tensorflow.Python;
78

89
namespace Tensorflow
910
{
10-
public class gradients_impl : Python
11+
public class gradients_impl
1112
{
1213
public static Tensor[] gradients(Tensor[] ys,
1314
Tensor[] xs,

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
using System.Collections.Generic;
44
using System.Linq;
55
using System.Text;
6+
using static Tensorflow.Python;
67

78
namespace Tensorflow.Gradients
89
{
910
/// <summary>
1011
/// Gradients for operators defined in math_ops.py.
1112
/// </summary>
12-
public class math_grad : Python
13+
public class math_grad
1314
{
1415
public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
1516
{

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 111 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Linq;
44
using System.Text;
55
using Tensorflow.Operations.ControlFlows;
6+
using static Tensorflow.Python;
67

78
namespace Tensorflow.Operations
89
{
@@ -46,9 +47,9 @@ public CondContext(Tensor pred = null,
4647
if (pred == null && context_def == null) return;
4748

4849
_name = ops.get_default_graph().unique_name(name);
49-
if (context_def != null)
50-
{
51-
_init_from_proto(context_def, import_scope: import_scope);
50+
if (context_def != null)
51+
{
52+
_init_from_proto(context_def, import_scope: import_scope);
5253
}
5354
else
5455
{
@@ -66,16 +67,16 @@ public CondContext(Tensor pred = null,
6667
}
6768
}
6869

69-
private void _init_from_proto(CondContextDef context_def, string import_scope = null)
70-
{
71-
var g = ops.get_default_graph();
72-
_name = ops.prepend_name_scope(context_def.ContextName, import_scope);
73-
var p1 = ops.prepend_name_scope(context_def.PredName, import_scope);
74-
_pred = g.as_graph_element(p1) as Tensor;
75-
var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope);
76-
_pivot = g.as_graph_element(p2) as Tensor;
77-
_branch = context_def.Branch;
78-
__init__(values_def: context_def.ValuesDef, import_scope: import_scope);
70+
private void _init_from_proto(CondContextDef context_def, string import_scope = null)
71+
{
72+
var g = ops.get_default_graph();
73+
_name = ops.prepend_name_scope(context_def.ContextName, import_scope);
74+
var p1 = ops.prepend_name_scope(context_def.PredName, import_scope);
75+
_pred = g.as_graph_element(p1) as Tensor;
76+
var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope);
77+
_pivot = g.as_graph_element(p2) as Tensor;
78+
_branch = context_def.Branch;
79+
__init__(values_def: context_def.ValuesDef, import_scope: import_scope);
7980
}
8081

8182
/// <summary>
@@ -90,8 +91,8 @@ public override Tensor AddValue(Tensor val)
9091
// Use the real value if it comes from outer context. This is needed in
9192
// particular for nested conds.
9293
if (_external_values.ContainsKey(val.name))
93-
result = _external_values[val.name];
94-
94+
result = _external_values[val.name];
95+
9596
result = result == null ? val : result;
9697
}
9798
else
@@ -107,10 +108,10 @@ public override Tensor AddValue(Tensor val)
107108
}
108109

109110
with(ops.control_dependencies(null), ctrl =>
110-
{
111-
var results = control_flow_ops._SwitchRefOrTensor(result, _pred);
112-
result = results[_branch];
113-
if (_outer_context != null)
111+
{
112+
var results = control_flow_ops._SwitchRefOrTensor(result, _pred);
113+
result = results[_branch];
114+
if (_outer_context != null)
114115
_outer_context.AddInnerOp(result.op);
115116
});
116117

@@ -127,87 +128,87 @@ public override Tensor AddValue(Tensor val)
127128
}
128129
_external_values[val.name] = result;
129130
}
130-
return result;
131-
}
132-
131+
return result;
132+
}
133+
133134
/// <summary>
134135
/// Add the subgraph defined by fn() to the graph.
135136
/// </summary>
136-
public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
137-
{
138-
// Add the subgraph defined by fn() to the graph.
139-
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
140-
var original_result = fn();
141-
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
142-
143-
//TODO: port this chunck of missing code:
144-
/*
145-
if len(post_summaries) > len(pre_summaries):
146-
new_summaries = post_summaries[len(pre_summaries):]
147-
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
148-
summary_ref[:] = pre_summaries
149-
with ops.control_dependencies(new_summaries):
150-
if original_result is None:
151-
return no_op(), None
152-
else:
153-
original_result = nest.map_structure(array_ops.identity,
154-
original_result)
155-
*/
156-
if (original_result == null)
157-
return (original_result, null);
158-
159-
switch (original_result)
160-
{
161-
case Tensor result:
162-
return (original_result, _BuildCondTensor(result));
163-
case Operation op:
164-
return (original_result, _BuildCondTensor(op));
137+
public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
138+
{
139+
// Add the subgraph defined by fn() to the graph.
140+
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
141+
var original_result = fn();
142+
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
143+
144+
//TODO: port this chunck of missing code:
145+
/*
146+
if len(post_summaries) > len(pre_summaries):
147+
new_summaries = post_summaries[len(pre_summaries):]
148+
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
149+
summary_ref[:] = pre_summaries
150+
with ops.control_dependencies(new_summaries):
151+
if original_result is None:
152+
return no_op(), None
153+
else:
154+
original_result = nest.map_structure(array_ops.identity,
155+
original_result)
156+
*/
157+
if (original_result == null)
158+
return (original_result, null);
159+
160+
switch (original_result)
161+
{
162+
case Tensor result:
163+
return (original_result, _BuildCondTensor(result));
164+
case Operation op:
165+
return (original_result, _BuildCondTensor(op));
165166
case float[] fv:
166167
{
167168
var result = ops.convert_to_tensor(fv[0]);
168169
return (original_result, _BuildCondTensor(result));
169-
}
170-
default:
171-
return (original_result, null);
172-
}
173-
}
174-
175-
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
176-
{
177-
// Add the subgraph defined by fn() to the graph.
178-
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
179-
var original_result = fn();
180-
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
181-
182-
switch (original_result)
183-
{
184-
case Tensor[] results:
185-
return (original_result, results.Select(_BuildCondTensor).ToArray());
186-
case Operation[] results:
187-
return (original_result, results.Select(_BuildCondTensor).ToArray());
188-
case float[] fv:
189-
var result = ops.convert_to_tensor(fv[0]);
190-
return (original_result, new Tensor[] { result });
191-
default:
192-
return (original_result, new Tensor[0]);
193-
}
194-
}
195-
196-
private Tensor _BuildCondTensor(ITensorOrOperation v)
197-
{
198-
switch (v)
199-
{
200-
case Operation op:
201-
// Use pivot as the proxy for this op.
202-
return control_flow_ops.with_dependencies(new Operation[] { op }, _pivot);
203-
case Tensor t:
204-
return _ProcessOutputTensor(t);
205-
default:
206-
return _ProcessOutputTensor(ops.convert_to_tensor(v));
207-
208-
}
209-
}
210-
170+
}
171+
default:
172+
return (original_result, null);
173+
}
174+
}
175+
176+
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
177+
{
178+
// Add the subgraph defined by fn() to the graph.
179+
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
180+
var original_result = fn();
181+
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
182+
183+
switch (original_result)
184+
{
185+
case Tensor[] results:
186+
return (original_result, results.Select(_BuildCondTensor).ToArray());
187+
case Operation[] results:
188+
return (original_result, results.Select(_BuildCondTensor).ToArray());
189+
case float[] fv:
190+
var result = ops.convert_to_tensor(fv[0]);
191+
return (original_result, new Tensor[] { result });
192+
default:
193+
return (original_result, new Tensor[0]);
194+
}
195+
}
196+
197+
private Tensor _BuildCondTensor(ITensorOrOperation v)
198+
{
199+
switch (v)
200+
{
201+
case Operation op:
202+
// Use pivot as the proxy for this op.
203+
return control_flow_ops.with_dependencies(new Operation[] { op }, _pivot);
204+
case Tensor t:
205+
return _ProcessOutputTensor(t);
206+
default:
207+
return _ProcessOutputTensor(ops.convert_to_tensor(v));
208+
209+
}
210+
}
211+
211212
/// <summary>
212213
/// Process an output tensor of a conditional branch.
213214
/// </summary>
@@ -238,7 +239,7 @@ private Tensor _ProcessOutputTensor(Tensor val)
238239
}
239240
return real_val;
240241
}
241-
242+
242243
protected override void _AddOpInternal(Operation op)
243244
{
244245
if (op.inputs.Length == 0)
@@ -324,20 +325,20 @@ public override bool back_prop
324325
}
325326
}
326327

327-
public CondContextDef to_proto(string export_scope)
328-
{
329-
throw new NotImplementedException();
330-
}
331-
332-
public CondContext from_proto(CondContextDef proto, string import_scope)
333-
{
334-
var ret = new CondContext(context_def: proto, import_scope: import_scope);
335-
336-
ret.Enter();
337-
foreach (var nested_def in proto.NestedContexts)
338-
from_control_flow_context_def(nested_def, import_scope: import_scope);
339-
ret.Exit();
340-
return ret;
341-
}
342-
}
328+
public CondContextDef to_proto(string export_scope)
329+
{
330+
throw new NotImplementedException();
331+
}
332+
333+
public CondContext from_proto(CondContextDef proto, string import_scope)
334+
{
335+
var ret = new CondContext(context_def: proto, import_scope: import_scope);
336+
337+
ret.Enter();
338+
foreach (var nested_def in proto.NestedContexts)
339+
from_control_flow_context_def(nested_def, import_scope: import_scope);
340+
ret.Exit();
341+
return ret;
342+
}
343+
}
343344
}

0 commit comments

Comments
 (0)