Skip to content

Commit d61ccf0

Browse files
committed
linear regression 2
1 parent 649623b commit d61ccf0

49 files changed

Lines changed: 312 additions & 151 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

data/linear_regression.zip

6 Bytes
Binary file not shown.

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public static partial class tf
1717
/// A `Tensor` with the same data as `input`, but its shape has an additional
1818
/// dimension of size 1 added.
1919
/// </returns>
20-
public static Tensor expand_dims(Tensor input, int axis = -1, string name = "", int dim = -1)
20+
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
2121
=> array_ops.expand_dims(input, axis, name, dim);
2222
}
2323
}

src/TensorFlowNET.Core/APIs/tf.io.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ namespace Tensorflow
66
{
77
public partial class tf
88
{
9-
public static Tensor read_file(string filename, string name = "") => gen_io_ops.read_file(filename, name);
9+
public static Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name);
1010

1111
public static gen_image_ops image => new gen_image_ops();
1212

1313
public static void import_graph_def(GraphDef graph_def,
1414
Dictionary<string, Tensor> input_map = null,
1515
string[] return_elements = null,
16-
string name = "",
16+
string name = null,
1717
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list);
1818
}
1919
}

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ public static partial class tf
1010

1111
public static Tensor sub(Tensor a, Tensor b) => gen_math_ops.sub(a, b);
1212

13-
public static Tensor subtract<T>(Tensor x, T[] y, string name = "") where T : struct
13+
public static Tensor subtract<T>(Tensor x, T[] y, string name = null) where T : struct
1414
=> gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name);
1515

1616
public static Tensor multiply(Tensor x, Tensor y) => gen_math_ops.mul(x, y);
1717

18-
public static Tensor divide<T>(Tensor x, T[] y, string name = "") where T : struct
18+
public static Tensor divide<T>(Tensor x, T[] y, string name = null) where T : struct
1919
=> x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");
2020

2121
public static Tensor pow<T1, T2>(T1 x, T2 y) => gen_math_ops.pow(x, y);
@@ -28,7 +28,7 @@ public static Tensor divide<T>(Tensor x, T[] y, string name = "") where T : stru
2828
/// <returns></returns>
2929
public static Tensor reduce_sum(Tensor input, int[] axis = null) => math_ops.reduce_sum(input);
3030

31-
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
31+
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
3232
=> math_ops.cast(x, dtype, name);
3333
}
3434
}

src/TensorFlowNET.Core/APIs/tf.random.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,6 @@ public static Tensor random_normal(int[] shape,
2121
float stddev = 1.0f,
2222
TF_DataType dtype = TF_DataType.TF_FLOAT,
2323
int? seed = null,
24-
string name = "") => random_ops.random_normal(shape, mean, stddev, dtype, seed, name);
24+
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name);
2525
}
2626
}

src/TensorFlowNET.Core/Eager/Execute.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace Tensorflow.Eager
66
{
77
public class Execute
88
{
9-
public void record_gradient(string op_name, InputList inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
9+
public void record_gradient(string op_name, InputList inputs, Dictionary<string, object> attrs, Tensor[] results, string name = null)
1010
{
1111
pywrap_tfe_src.RecordGradient(op_name, inputs._inputs, attrs, results, name);
1212
}

src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace Tensorflow.Eager
1010
/// </summary>
1111
public class pywrap_tfe_src
1212
{
13-
public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
13+
public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = null)
1414
{
1515
var input_ids = inputs.Select(x => x.Id).ToArray();
1616
var input_dtypes = inputs.Select(x => x.dtype).ToArray();

src/TensorFlowNET.Core/Framework/importer.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public class importer
1212
public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
1313
Dictionary<string, Tensor> input_map = null,
1414
string[] return_elements = null,
15-
string name = "",
15+
string name = null,
1616
OpList producer_op_list = null)
1717
{
1818
var op_dict = op_def_registry.get_registered_ops();

src/TensorFlowNET.Core/Framework/meta_graph.py.cs

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ public static (Dictionary<string, RefVariable>, ITensorOrOperation[]) import_sco
123123
/// <param name="strip_default_attrs"></param>
124124
/// <param name="meta_info_def"></param>
125125
/// <returns></returns>
126-
public static MetaGraphDef export_scoped_meta_graph(string filename = "",
126+
public static (MetaGraphDef, Dictionary<string, RefVariable>) export_scoped_meta_graph(string filename = "",
127127
GraphDef graph_def = null,
128128
bool as_text = false,
129129
string unbound_inputs_col_name = "unbound_inputs",
@@ -138,7 +138,7 @@ public static MetaGraphDef export_scoped_meta_graph(string filename = "",
138138
var var_list = new Dictionary<string, RefVariable>();
139139
var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES);
140140

141-
foreach(var v in variables as RefVariable[])
141+
foreach(var v in variables as List<RefVariable>)
142142
{
143143
var_list[v.name] = v;
144144
}
@@ -151,15 +151,18 @@ public static MetaGraphDef export_scoped_meta_graph(string filename = "",
151151
saver_def: saver_def,
152152
strip_default_attrs: strip_default_attrs);
153153

154-
throw new NotImplementedException("meta_graph.export_scoped_meta_graph");
154+
if (!string.IsNullOrEmpty(filename))
155+
graph_io.write_graph(scoped_meta_graph_def, "", filename, as_text: as_text);
156+
157+
return (scoped_meta_graph_def, var_list);
155158
}
156159

157160
private static bool _should_include_node()
158161
{
159162
return true;
160163
}
161164

162-
private static byte[] create_meta_graph_def(MetaInfoDef meta_info_def = null,
165+
private static MetaGraphDef create_meta_graph_def(MetaInfoDef meta_info_def = null,
163166
GraphDef graph_def = null,
164167
string export_scope = "",
165168
string exclude_nodes = "",
@@ -168,7 +171,7 @@ private static byte[] create_meta_graph_def(MetaInfoDef meta_info_def = null,
168171
bool strip_default_attrs = false)
169172
{
170173
// Sets graph to default graph if it's not passed in.
171-
var graph = ops.get_default_graph();
174+
var graph = ops.get_default_graph().as_default();
172175
// Creates a MetaGraphDef proto.
173176
var meta_graph_def = new MetaGraphDef();
174177
if (meta_info_def == null)
@@ -186,10 +189,55 @@ private static byte[] create_meta_graph_def(MetaInfoDef meta_info_def = null,
186189
meta_graph_def.GraphDef = graph_def;
187190

188191
// Fills in meta_info_def.stripped_op_list using the ops from graph_def.
189-
if (meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0)
192+
if (meta_graph_def.MetaInfoDef.StrippedOpList == null ||
193+
meta_graph_def.MetaInfoDef.StrippedOpList.Op.Count == 0)
190194
meta_graph_def.MetaInfoDef.StrippedOpList = stripped_op_list_for_graph(meta_graph_def.GraphDef);
191195

192-
throw new NotImplementedException("create_meta_graph_def");
196+
var clist = graph.get_all_collection_keys();
197+
foreach(var ctype in clist)
198+
{
199+
if (clear_extraneous_savers)
200+
{
201+
throw new NotImplementedException("create_meta_graph_def clear_extraneous_savers");
202+
}
203+
else
204+
{
205+
add_collection_def(meta_graph_def, ctype, graph);
206+
}
207+
}
208+
209+
return meta_graph_def;
210+
}
211+
212+
private static void add_collection_def(MetaGraphDef meta_graph_def,
213+
string key,
214+
Graph graph = null,
215+
string export_scope = "")
216+
{
217+
if (!meta_graph_def.CollectionDef.ContainsKey(key))
218+
meta_graph_def.CollectionDef[key] = new CollectionDef();
219+
var col_def = meta_graph_def.CollectionDef[key];
220+
221+
switch (graph.get_collection(key))
222+
{
223+
case List<RefVariable> collection_list:
224+
col_def.BytesList = new Types.BytesList();
225+
foreach (var x in collection_list)
226+
{
227+
var proto = x.to_proto(export_scope);
228+
col_def.BytesList.Value.Add(proto.ToByteString());
229+
}
230+
231+
break;
232+
case List<object> collection_list:
233+
col_def.NodeList = new Types.NodeList();
234+
foreach (var x in collection_list)
235+
if (x is ITensorOrOperation x2)
236+
col_def.NodeList.Value.Add(ops.strip_name_scope(x2.name, export_scope));
237+
break;
238+
case List<Operation> collection_list:
239+
break;
240+
}
193241
}
194242

195243
private static OpList stripped_op_list_for_graph(GraphDef graph_def)

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tenso
118118

119119
if (obj is Tensor tensor && allow_tensor)
120120
{
121-
if (tensor.Graph.Equals(this))
121+
if (tensor.graph.Equals(this))
122122
{
123123
return tensor;
124124
}
@@ -164,7 +164,7 @@ private void _check_not_finalized()
164164
}
165165

166166
public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
167-
TF_DataType[] input_types = null, string name = "",
167+
TF_DataType[] input_types = null, string name = null,
168168
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
169169
{
170170
if (inputs == null)

0 commit comments

Comments
 (0)