Skip to content

Commit 6a9ccea

Browse files
committed
Resolve some wrong implementations.
1 parent 4c1878b commit 6a9ccea

15 files changed

Lines changed: 114 additions & 62 deletions

File tree

src/TensorFlowNET.Core/Buffers/TF_Buffer.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,32 @@ public struct TF_Buffer
2525
public IntPtr data;
2626
public ulong length;
2727
public IntPtr data_deallocator;
28+
29+
public unsafe Span<T> AsSpan<T>() where T: unmanaged
30+
{
31+
if(length > int.MaxValue)
32+
{
33+
throw new ValueError($"The length {length} is too large to use in the span.");
34+
}
35+
return new Span<T>(data.ToPointer(), (int)length);
36+
}
37+
38+
public unsafe byte[] ToByteArray()
39+
{
40+
byte[] res = new byte[length];
41+
if(length > int.MaxValue)
42+
{
43+
byte* root = (byte*)data;
44+
for(ulong i = 0; i < length; i++)
45+
{
46+
res[i] = *(root++);
47+
}
48+
}
49+
else
50+
{
51+
new Span<byte>(data.ToPointer(), (int)length).CopyTo(res.AsSpan());
52+
}
53+
return res;
54+
}
2855
}
2956
}

src/TensorFlowNET.Core/Eager/execute.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] valu
1818
var types = v.Select(t => t.dtype.as_datatype_enum());
1919
return (types.ToArray(), v.ToArray());
2020
}
21+
public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
22+
{
23+
return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name);
24+
}
2125
public static Tensor[] quick_execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null)
2226
{
2327
string device_name = ctx.DeviceName;

src/TensorFlowNET.Core/Framework/importer.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ private static void _ProcessNewOps(Graph graph)
149149
foreach (var new_op in graph._add_new_tf_operations())
150150
{
151151
var original_device = new_op.Device;
152+
new_op._set_device(original_device);
152153
}
153154
}
154155

src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
using Google.Protobuf;
22
using System;
33
using System.Collections.Generic;
4+
using System.IO;
45
using System.Linq;
56
using System.Text;
67
using Tensorflow.Contexts;
8+
using Tensorflow.Eager;
79
using Tensorflow.Graphs;
810
using Tensorflow.Operations;
911
using Tensorflow.Util;
@@ -16,6 +18,8 @@ public class EagerDefinedFunction
1618
public int _num_outputs;
1719
FuncGraph _func_graph;
1820
FunctionDef _definition;
21+
OpDef _signature;
22+
string _name;
1923
Tensor[] _func_graph_outputs;
2024
public string Name => _func_graph.FuncName;
2125
public DataType[] OutputTypes { get; protected set; }
@@ -31,6 +35,18 @@ public FunctionDef Definition
3135
return _definition;
3236
}
3337
}
38+
39+
public OpDef Signature
40+
{
41+
get
42+
{
43+
if( _signature is null)
44+
{
45+
_signature = Definition.Signature;
46+
}
47+
return _signature;
48+
}
49+
}
3450
public EagerDefinedFunction(string name, FuncGraph graph,
3551
Tensors inputs, Tensors outputs,
3652
Dictionary<string, string> attrs)
@@ -75,12 +91,12 @@ public Tensors Call(Tensors args)
7591
Tensor[] outputs;
7692
if (executing_eagerly)
7793
{
78-
outputs = tf.Runner.TFE_Execute(tf.Context,
79-
tf.Context.DeviceName,
80-
_func_graph.FuncName,
81-
args,
82-
attrs,
83-
_num_outputs);
94+
outputs = execute.executes(
95+
Signature.Name,
96+
_num_outputs,
97+
args,
98+
attrs,
99+
tf.Context);
84100
}
85101
else
86102
{
@@ -135,9 +151,13 @@ public void AddToGraph(Graph g = null)
135151
private FunctionDef _get_definition()
136152
{
137153
var buffer = c_api_util.tf_buffer();
138-
// TODO(Rinne): pywrap_tf_session.TF_FunctionToFunctionDef
154+
Status status = new();
155+
c_api.TF_FunctionToFunctionDef(_func_graph._func_graph_handle, buffer, status);
156+
status.Check(true);
139157
var proto_data = c_api.TF_GetBuffer(buffer);
140-
throw new NotImplementedException();
158+
FunctionDef function_def = new();
159+
function_def.MergeFrom(proto_data.AsSpan<byte>());
160+
return function_def;
141161
}
142162
}
143163
}

src/TensorFlowNET.Core/Graphs/FuncGraph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace Tensorflow.Graphs;
1010
/// </summary>
1111
public class FuncGraph : Graph, IDisposable
1212
{
13-
SafeFuncGraphHandle _func_graph_handle;
13+
internal SafeFuncGraphHandle _func_graph_handle;
1414
public string FuncName => _graph_key;
1515

1616
public Tensors Inputs { get; set; } = new Tensors();

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,19 @@ public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
238238
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
239239
}
240240

241+
[Obsolete("The implementation is not complete.")]
242+
internal void _set_device_from_string(string device_str)
243+
{
244+
// TODO(Rinne): complete it with new C API `SetRequestedDevice`.
245+
//c_api.TF_SetDevice(_handle, device_str);
246+
}
247+
248+
[Obsolete("The implementation is not complete.")]
249+
internal void _set_device(string device)
250+
{
251+
_set_device_from_string(device);
252+
}
253+
241254
private NodeDef GetNodeDef()
242255
{
243256
var buffer = new Buffer();

src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,8 @@ public Loader(SavedObjectGraph object_graph_proto, SavedModel saved_model_proto,
4545
_asset_file_def = meta_graph.AssetFileDef;
4646
_operation_attributes = meta_graph.GraphDef.Node.ToDictionary(x => x.Name, x => x.Attr);
4747
_proto = object_graph_proto;
48-
// Debug(Rinne)
49-
var temp = _proto.ToString();
5048
_export_dir = export_dir;
51-
// TODO: `this._concrete_functions` and `this._restored_concrete_functions`
52-
// TODO(Rinne): This method is very slow, needs to be accelareted.
49+
// TODO(Rinne): This method is a bit slow (especially under debug mode), may need to be accelareted.
5350
_concrete_functions = function_deserialization.load_function_def_library(
5451
meta_graph.GraphDef.Library, _proto);
5552
_restored_concrete_functions = new HashSet<string>();
@@ -322,11 +319,6 @@ private void _load_checkpoint_save_and_restore_functions()
322319
foreach(var (node_id, proto) in _iter_all_nodes())
323320
{
324321
var node = get(node_id);
325-
if(node is null)
326-
{
327-
// skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
328-
continue;
329-
}
330322
if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
331323
{
332324
// Restore Trackable serialize- and restore-from-tensor functions.
@@ -390,7 +382,7 @@ private void _load_nodes()
390382
var optimizer_object = nodes[optimizer_node_id];
391383
var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];
392384

393-
// TODO: implement it.
385+
// TODO(Rinne): implement it.
394386
throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." +
395387
" Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
396388
}
@@ -508,21 +500,11 @@ public Trackable get(string node_id)
508500
/// <param name="node_id"></param>
509501
private void _add_object_graph_edges(SavedObject proto, int node_id)
510502
{
511-
// Debug(Rinne)
512-
if(node_id == 1)
513-
{
514-
Console.WriteLine();
515-
}
516503
var obj = _nodes[node_id];
517504
var setter = _node_setters[node_id];
518505

519506
foreach(var refer in proto.Children)
520507
{
521-
if(obj is null)
522-
{
523-
// skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
524-
continue;
525-
}
526508
setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]);
527509
// TODO(Rinne): deal with "__call__"
528510
}
@@ -553,12 +535,6 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
553535
private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
554536
{
555537
// skip the registered classes.
556-
if(node_id == 16)
557-
{
558-
// Debug(Rinne)
559-
Console.WriteLine();
560-
}
561-
562538
Dictionary<OneOf<string, int>, Trackable> dependencies = new();
563539
foreach(var item in _get_node_dependencies(proto))
564540
{

src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ public BaseResourceVariable()
6565
}
6666

6767
public void __init__(bool trainable = true,
68+
Shape shape = null,
69+
TF_DataType dtype = TF_DataType.DtInvalid,
6870
Tensor handle = null,
6971
string name = null,
7072
string unique_id = null,
@@ -75,6 +77,14 @@ public void __init__(bool trainable = true,
7577
_unique_id = unique_id;
7678
this.handle = handle;
7779
_name = name;
80+
if(shape is not null)
81+
{
82+
_shape = shape;
83+
}
84+
if(dtype != TF_DataType.DtInvalid)
85+
{
86+
_dtype = dtype;
87+
}
7888

7989
// After the handle has been created, set up a way to clean it up when
8090
// executing eagerly. We'll hold the only reference to the deleter, so that

src/TensorFlowNET.Core/Variables/ResourceVariable.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,11 @@ private void _init_from_args(object initial_value = null,
116116
}
117117
});
118118

119-
_shape = shape ?? _initial_value.shape;
119+
if(shape is null)
120+
{
121+
shape = _initial_value.shape;
122+
}
123+
dtype = _initial_value.dtype;
120124

121125
if (_in_graph_mode)
122126
{
@@ -135,7 +139,7 @@ private void _init_from_args(object initial_value = null,
135139
{
136140
handle = resource_variable_ops.eager_safe_variable_handle(
137141
initial_value: _initial_value,
138-
shape: _shape,
142+
shape: shape,
139143
shared_name: shared_name,
140144
name: name,
141145
graph_mode: _in_graph_mode);
@@ -154,6 +158,8 @@ private void _init_from_args(object initial_value = null,
154158
}
155159

156160
base.__init__(trainable: trainable,
161+
shape: shape,
162+
dtype: dtype,
157163
handle: handle,
158164
name: name,
159165
unique_id: unique_id,

src/TensorFlowNET.Core/Variables/UninitializedVariable.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ public UninitializedVariable(
5050
{
5151
tf_with(ops.name_scope("Read"), _ =>
5252
{
53-
tf.device(handle.Device);
54-
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
55-
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
53+
tf.device(created_handle.Device);
54+
var value = gen_resource_variable_ops.read_variable_op(created_handle, dtype);
55+
resource_variable_ops._maybe_set_handle_data(dtype, created_handle, value);
5656
_graph_element = value;
5757
});
5858
ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);
@@ -63,9 +63,7 @@ public UninitializedVariable(
6363
}
6464
});
6565
});
66-
_shape = shape;
67-
_dtype = dtype;
68-
base.__init__(trainable, created_handle, unique_id: unique_id, handle_name: handle_name);
66+
base.__init__(trainable, shape, dtype, created_handle, unique_id: unique_id, handle_name: handle_name);
6967
}
7068
}
7169
}

0 commit comments

Comments
 (0)