forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFunction.cs
More file actions
84 lines (72 loc) · 2.64 KB
/
Function.cs
File metadata and controls
84 lines (72 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
using System;
using Tensorflow.Functions;
using Tensorflow.Train;
namespace Tensorflow
{
    /// <summary>
    /// Callable wrapper around a C# tensor function (<c>Func&lt;Tensor[], Tensor[]&gt;</c>).
    /// <see cref="Apply"/> either invokes the wrapped delegate directly (when
    /// <see cref="_run_functions_eagerly"/> returns true) or routes the call through a
    /// lazily-created <see cref="TracingCompiler"/>.
    /// </summary>
    /// <remarks>
    /// The compiler and the first concrete function are created on demand without any
    /// locking, so concurrent first calls are not coordinated by this class.
    /// </remarks>
    public class Function: Trackable, IGenericFunction
    {
#pragma warning disable CS0169 // The field 'Function._handle' is never used
        private IntPtr _handle;
#pragma warning restore CS0169 // The field 'Function._handle' is never used

        /// <summary>The user-supplied function that this object wraps.</summary>
        protected Func<Tensor[], Tensor[]> _csharp_function;

        /// <summary>Concrete function obtained when the compiler is first created in <c>_initialize</c>.</summary>
        protected ConcreteFunction _concrete_variable_creation_fn;

        /// <summary>Forwarded to the <see cref="TracingCompiler"/> as its <c>autograph</c> argument.</summary>
        protected bool _autograph;

        /// <summary>Tracing compiler; <c>null</c> until the first traced call or concrete-function request.</summary>
        protected TracingCompiler _variable_creation_fn;

        /// <summary>Name of this function; also used as the name of its tracing compiler.</summary>
        public string Name { get; set; }

        /// <summary>Wraps <paramref name="csharp_function"/> without tracing it yet.</summary>
        /// <param name="csharp_function">The function to wrap.</param>
        /// <param name="name">Name for this function and its compiler.</param>
        /// <param name="auto_graph">Value forwarded to the compiler's <c>autograph</c> option.</param>
        public Function(Func<Tensor[], Tensor[]> csharp_function,
            string name, bool auto_graph = true)
        {
            _csharp_function = csharp_function;
            Name = name;
            _autograph = auto_graph;
        }

        /// <summary>
        /// Invokes the function on <paramref name="inputs"/>: directly when running eagerly,
        /// otherwise through the traced path in <see cref="_call"/>.
        /// </summary>
        public virtual Tensors Apply(Tensors inputs)
        {
            if (_run_functions_eagerly())
            {
                return _csharp_function(inputs);
            }
            var result = _call(inputs);
            return result;
        }

        /// <summary>
        /// Returns the concrete function for <paramref name="args"/>, creating the tracing
        /// compiler first if needed.
        /// </summary>
        public ConcreteFunction get_concrete_function(params Tensor[] args)
        {
            return _get_concrete_function_garbage_collected(args);
        }

        /// <summary>
        /// Traced call path. Once the compiler exists, the call is delegated to it. On the
        /// first call, the compiler is created and the resulting concrete function is invoked
        /// with its captured inputs.
        /// </summary>
        protected virtual Tensors _call(Tensors inputs)
        {
            if(_variable_creation_fn is not null)
            {
                return _variable_creation_fn.Apply(inputs);
            }
            _initialize(inputs);
            return _concrete_variable_creation_fn.CallFlat(inputs,
                _concrete_variable_creation_fn.CapturedInputs);
        }

        /// <summary>
        /// Creates a <see cref="TracingCompiler"/> for <paramref name="fn"/>, named after this
        /// function and using this function's autograph setting.
        /// </summary>
        protected TracingCompiler _compiler(Func<Tensor[], Tensor[]> fn)
        {
            // Pass the function's real name. The previous `nameof(fn)` always evaluated to the
            // literal string "fn" (nameof yields an identifier's spelling, not its value).
            return new TracingCompiler(fn, Name, autograph: _autograph);
        }

        /// <summary>
        /// Whether <see cref="Apply"/> should bypass tracing and call the delegate directly.
        /// The base implementation always returns <c>false</c>; subclasses may override.
        /// </summary>
        protected virtual bool _run_functions_eagerly()
        {
            return false;
        }

        /// <summary>
        /// Returns the concrete function for <paramref name="args"/>, lazily initializing the
        /// tracing compiler.
        /// </summary>
        protected ConcreteFunction _get_concrete_function_garbage_collected(Tensor[] args)
        {
            if(_variable_creation_fn is null)
            {
                _initialize(args);
                // TODO(Rinne): _initialize_uninitialized_variables
            }
            // NOTE(review): right after _initialize this requests the same concrete function a
            // second time (mirrors upstream Python); presumably served from the compiler's
            // cache — confirm.
            var concrete = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args);
            return concrete;
        }

        /// <summary>
        /// Creates the tracing compiler, names it after this function, and records the
        /// concrete function it produces for <paramref name="args"/>.
        /// </summary>
        private void _initialize(Tensor[] args)
        {
            _variable_creation_fn = _compiler(_csharp_function);
            _variable_creation_fn._name = this.Name;
            _concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args);
        }
    }
}