|
| 1 | +using System; |
| 2 | +using System.Collections.Generic; |
| 3 | +using System.Text; |
| 4 | + |
| 5 | +namespace Tensorflow |
| 6 | +{ |
| 7 | + /// <summary> |
| 8 | + /// Variable store that carries a number of named Variables. |
| 9 | + /// </summary> |
| 10 | + public class _VariableStore |
| 11 | + { |
| 12 | + private Dictionary<string, object> _vars; |
| 13 | + private Dictionary<string, object> _partitioned_vars; |
| 14 | + private bool _store_eager_variables; |
| 15 | + |
| 16 | + public _VariableStore() |
| 17 | + { |
| 18 | + _vars = new Dictionary<string, object>(); |
| 19 | + _partitioned_vars = new Dictionary<string, object>(); |
| 20 | + _store_eager_variables = false; |
| 21 | + } |
| 22 | + |
| 23 | + public RefVariable get_variable(string name, |
| 24 | + TensorShape shape = null, |
| 25 | + TF_DataType dtype = TF_DataType.TF_FLOAT, |
| 26 | + IInitializer initializer = null, |
| 27 | + bool trainable = false, |
| 28 | + bool validate_shape = true, |
| 29 | + VariableSynchronization synchronization = VariableSynchronization.AUTO, |
| 30 | + VariableAggregation aggregation = VariableAggregation.NONE) |
| 31 | + { |
| 32 | + dtype = dtype.as_base_dtype(); |
| 33 | + trainable = variable_scope._get_trainable_value(synchronization, trainable); |
| 34 | + |
| 35 | + return _true_getter(name, |
| 36 | + shape: shape, |
| 37 | + dtype: dtype, |
| 38 | + initializer: initializer, |
| 39 | + trainable: trainable, |
| 40 | + validate_shape: validate_shape, |
| 41 | + synchronization: synchronization, |
| 42 | + aggregation: aggregation); |
| 43 | + } |
| 44 | + |
| 45 | + private RefVariable _true_getter(string name, |
| 46 | + TensorShape shape = null, |
| 47 | + TF_DataType dtype = TF_DataType.DtInvalid, |
| 48 | + IInitializer initializer = null, |
| 49 | + bool trainable = false, |
| 50 | + bool validate_shape = true, |
| 51 | + VariableSynchronization synchronization = VariableSynchronization.AUTO, |
| 52 | + VariableAggregation aggregation = VariableAggregation.NONE) |
| 53 | + { |
| 54 | + return _get_single_variable(name: name); |
| 55 | + } |
| 56 | + |
| 57 | + private RefVariable _get_single_variable(string name, |
| 58 | + TensorShape shape = null, |
| 59 | + TF_DataType dtype = TF_DataType.DtInvalid, |
| 60 | + IInitializer initializer = null, |
| 61 | + bool reuse = false, |
| 62 | + bool trainable = false, |
| 63 | + bool validate_shape = false, |
| 64 | + VariableSynchronization synchronization = VariableSynchronization.AUTO, |
| 65 | + VariableAggregation aggregation = VariableAggregation.NONE) |
| 66 | + { |
| 67 | + if (_vars.ContainsKey(name)) |
| 68 | + { |
| 69 | + if (!reuse) |
| 70 | + { |
| 71 | + var var = _vars[name]; |
| 72 | + |
| 73 | + } |
| 74 | + throw new NotImplementedException("_get_single_variable"); |
| 75 | + } |
| 76 | + |
| 77 | + throw new NotImplementedException("_get_single_variable"); |
| 78 | + } |
| 79 | + } |
| 80 | +} |
0 commit comments