forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbackend.cs
More file actions
128 lines (112 loc) · 5.39 KB
/
backend.cs
File metadata and controls
128 lines (112 loc) · 5.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
using System;
using System.Collections.Generic;
using static Tensorflow.Binding;
namespace Tensorflow.Keras
{
public class backend : BackendBase
{
    /* ---------------------------------------- KERAS BACKEND NATIVE OBJECTS ---------------------------------------- */
    public static Func<Array, double> py_sum = sum;
    public static Func<Array, bool> py_all = all;
    //Func<Array, bool> py_any = any;
    //Func<double, double, double, IEnumerable<double>> py_slice = slice;
    public static Session _SESSION = ops.get_default_session();
    public static Graph _GRAPH = null;

    /// <summary>
    /// Maps each graph to its current learning phase.
    /// Initialized eagerly so that <see cref="learning_phase"/> and <see cref="set_learning_phase"/>
    /// can be used before <see cref="clear_session"/> has ever run (previously this field was null
    /// until <see cref="clear_session"/> assigned it).
    /// </summary>
    public static Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();
    public static bool _MANUAL_VAR_INIT = false;
    public static List<string> _LOCAL_DEVICES = null;
    /* -------------------------------------- KERAS BACKEND NATIVE OBJECTS END -------------------------------------- */

    /// <summary>
    /// A global dictionary mapping graph objects to an index of counters used
    /// for various layer names in each graph.
    /// Allows unique autogenerated names to be given to layers, in a graph-specific way.
    /// Inner keys are <c>(namespace, prefix)</c> pairs; values are the last uid handed out.
    /// </summary>
    public static Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();

    /// <summary>
    /// Variables registered through <see cref="track_variable"/>, keyed by their graph's key.
    /// </summary>
    public static Dictionary<string, VariableV1> _GRAPH_VARIABLES = new Dictionary<string, VariableV1>();
    public static Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>();
    public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph();

    /// <summary>
    /// Records <paramref name="v"/> in <see cref="_GRAPH_VARIABLES"/> under the key of the graph it belongs to.
    /// </summary>
    /// <param name="v">The variable to track.</param>
    /// <remarks>
    /// NOTE(review): only one variable is kept per graph key; a later call for the same graph
    /// replaces the earlier entry. Confirm whether callers rely on this before changing the
    /// (public) field type to hold a collection per graph.
    /// </remarks>
    public static void track_variable(VariableV1 v)
    {
        var graph = v.graph;
        _GRAPH_VARIABLES[graph.graph_key] = v;
    }

    /// <summary>
    /// Creates a placeholder tensor via <c>gen_array_ops.placeholder</c>.
    /// </summary>
    /// <param name="shape">Placeholder dimensions, or null.</param>
    /// <param name="ndim">
    /// Used only when <paramref name="shape"/> is null and <paramref name="ndim"/> is positive:
    /// the placeholder then gets <paramref name="ndim"/> dimensions, each set to -1
    /// (assumed to be TensorShape's unknown-dimension marker).
    /// </param>
    /// <param name="dtype">Element type of the placeholder.</param>
    /// <param name="sparse">Sparse placeholders are not supported.</param>
    /// <param name="name">Optional op name.</param>
    /// <returns>The placeholder tensor.</returns>
    /// <exception cref="NotImplementedException">When <paramref name="sparse"/> is true.</exception>
    public static Tensor placeholder(int[] shape = null,
        int ndim = -1,
        TF_DataType dtype = TF_DataType.DtInvalid,
        bool sparse = false,
        string name = null)
    {
        if (sparse)
            throw new NotImplementedException("placeholder sparse is true");

        // Mirror Keras: when only a rank is given, every dimension is unknown.
        if (shape == null && ndim > 0)
        {
            shape = new int[ndim];
            for (int i = 0; i < ndim; i++)
                shape[i] = -1;
        }

        return gen_array_ops.placeholder(dtype: dtype, shape: new TensorShape(shape), name: name);
    }

    /// <summary>Returns the current default graph.</summary>
    public static Graph get_graph()
    {
        return ops.get_default_graph();
    }

    /// <summary>
    /// Returns the next uid for <paramref name="prefix"/> within <paramref name="namespace"/>
    /// in the current default graph. Counting is per graph and starts at 1.
    /// </summary>
    /// <param name="prefix">Name prefix to count (e.g. a layer type name).</param>
    /// <param name="namespace">Optional namespace that scopes the counter.</param>
    /// <returns>The incremented counter value.</returns>
    public static int get_uid(string prefix, string @namespace = "")
        => get_uid((@namespace, prefix));

    /// <summary>
    /// Returns the next uid for the <c>(namespace, prefix)</c> pair <paramref name="name"/>
    /// in the current default graph. Counting is per graph and starts at 1.
    /// </summary>
    /// <param name="name">The <c>(namespace, prefix)</c> key to count.</param>
    /// <returns>The incremented counter value.</returns>
    public static int get_uid((string, string) name)
    {
        var graph = tf.get_default_graph();

        Dictionary<(string, string), int> uids;
        if (!PER_GRAPH_LAYER_NAME_UIDS.TryGetValue(graph, out uids))
        {
            uids = new Dictionary<(string, string), int>();
            PER_GRAPH_LAYER_NAME_UIDS.Add(graph, uids);
        }

        // The inner map is statically typed as Dictionary, whose indexer getter throws on a
        // missing key (a subclass's hiding indexer is not used through this reference).
        // Read with TryGetValue (0 when absent) and write the incremented value back.
        uids.TryGetValue(name, out var count);
        count += 1;
        uids[name] = count;
        return count;
    }

    /// <summary>Discards all per-graph layer-name counters.</summary>
    public static void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();

    /// <summary>
    /// Resets the default graph, the uid counters and the cached session. It then registers the
    /// new default graph with learning phase 0.
    /// </summary>
    public static void clear_session()
    {
        ops.reset_default_graph();
        reset_uids();
        _SESSION = null;
        // The returned tensor is not stored. Presumably the call is kept so the
        // "keras_learning_phase" op exists in the fresh graph — confirm before removing.
        tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
        _GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();
        _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0;
    }

    /// <summary>Sets the <see cref="_MANUAL_VAR_INIT"/> flag.</summary>
    /// <param name="value">New flag value.</param>
    public static void manual_variable_initialization(bool value)
    {
        _MANUAL_VAR_INIT = value;
    }

    /// <summary>
    /// Returns the learning phase of the current default graph. The first time a graph is seen,
    /// it is registered with phase 0; after that, the stored value (e.g. from
    /// <see cref="set_learning_phase"/>) is returned unchanged.
    /// </summary>
    /// <returns>The graph's learning phase.</returns>
    public static GraphLearningPhase learning_phase()
    {
        var graph = tf.get_default_graph();

        // Previously the check was inverted. It threw for unseen graphs and reset known graphs to 0.
        GraphLearningPhase phase;
        if (_GRAPH_LEARNING_PHASES.TryGetValue(graph, out phase))
            return phase;

        // Unseen graph: add the learning-phase op (tensor not stored) and default the phase to 0.
        tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase");
        GraphLearningPhase initial = 0;
        _GRAPH_LEARNING_PHASES[graph] = initial;
        return initial;
    }

    /// <summary>Sets the learning phase of the current default graph (true → 1, false → 0).</summary>
    /// <param name="value">The phase flag to store.</param>
    public static void set_learning_phase(bool value)
    {
        _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)(value ? 1 : 0);
    }

    /// <summary>Marker type; it has no members in this file.</summary>
    public class _DummyEagerGraph
    { }
}
}