forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMetric.cs
More file actions
69 lines (61 loc) · 2.16 KB
/
Metric.cs
File metadata and controls
69 lines (61 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Metrics
{
    /// <summary>
    /// Encapsulates metric logic and state.
    /// </summary>
    /// <remarks>
    /// Subclasses override <c>update_state</c> and <c>result</c> and create their state
    /// variables through <c>add_weight</c>. The constructor marks the layer as stateful and
    /// already built.
    /// </remarks>
    public class Metric : Layer, IMetricFunc
    {
        /// <summary>
        /// Running-total accumulator. Not assigned by this class; subclasses are expected to
        /// create it, and it may be <c>null</c> if a subclass never does.
        /// </summary>
        protected IVariableV1 total;

        /// <summary>
        /// Running-count accumulator. Not assigned by this class; it may be <c>null</c> if a
        /// subclass never creates it.
        /// </summary>
        protected IVariableV1 count;

        /// <summary>Reduction mode name. Not read or written by this class; set by subclasses.</summary>
        protected string _reduction;

        /// <summary>Metric data type. Not assigned by this class; set by subclasses.</summary>
        protected TF_DataType _dtype;

        /// <summary>
        /// Creates a metric layer.
        /// </summary>
        /// <param name="name">Layer name.</param>
        /// <param name="dtype">Layer data type; defaults to <c>TF_DataType.DtInvalid</c>.</param>
        public Metric(string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
            : base(new LayerArgs
            {
                Name = name,
                DType = dtype
            })
        {
            stateful = true;
            // Metrics have no input-dependent build step, so mark the layer built immediately.
            built = true;
        }

        /// <summary>
        /// Creates a state variable for this metric. Metric state is always created as
        /// non-trainable, regardless of the <paramref name="trainable"/> argument.
        /// </summary>
        /// <param name="name">Variable name.</param>
        /// <param name="shape">Variable shape; a rank-0 (scalar) shape is used when <c>null</c>.</param>
        /// <param name="dtype">Variable data type.</param>
        /// <param name="initializer">Initializer for the variable's value.</param>
        /// <param name="regularizer">Ignored: not forwarded to the underlying layer.</param>
        /// <param name="synchronization">Variable synchronization mode.</param>
        /// <param name="aggregation">Variable aggregation mode.</param>
        /// <param name="trainable">Ignored: metric state is always created with <c>trainable: false</c>.</param>
        /// <param name="getter">Ignored: not forwarded to the underlying layer.</param>
        /// <returns>The created variable.</returns>
        protected override IVariableV1 add_weight(string name,
            Shape shape = null,
            TF_DataType dtype = TF_DataType.TF_FLOAT,
            IInitializer initializer = null,
            IRegularizer regularizer = null,
            VariableSynchronization synchronization = VariableSynchronization.OnRead,
            VariableAggregation aggregation = VariableAggregation.Sum,
            bool trainable = true,
            Func<VariableArgs, IVariableV1> getter = null)
        {
            if (shape == null)
            {
                shape = new Shape(new int[0]);
            }

            // Variables are created under ops.init_scope(); presumably this lifts creation out of
            // any graph/function currently being built -- NOTE(review): confirm against ops.init_scope.
            return tf_with(ops.init_scope(), delegate
            {
                // trainable is deliberately forced to false; regularizer and getter are not forwarded.
                return base.add_weight(name, shape,
                    dtype: dtype,
                    trainable: false,
                    initializer: initializer,
                    synchronization: synchronization,
                    aggregation: aggregation);
            });
        }

        /// <summary>
        /// Accumulates statistics for the metric. Must be overridden by subclasses.
        /// </summary>
        /// <param name="y_true">Ground-truth values.</param>
        /// <param name="y_pred">Predicted values.</param>
        /// <param name="sample_weight">Optional per-sample weights.</param>
        /// <returns>Defined by the overriding subclass.</returns>
        /// <exception cref="NotImplementedException">Always thrown by this base implementation.</exception>
        public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
            => throw new NotImplementedException(
                $"{GetType().Name} does not implement {nameof(update_state)}.");

        /// <summary>
        /// Resets the metric's state by assigning zero to every variable in <c>Weights</c>.
        /// </summary>
        public virtual void reset_states()
        {
            foreach (var v in Weights)
            {
                v.assign(0);
            }
        }

        /// <summary>
        /// Computes the current metric value. Must be overridden by subclasses.
        /// </summary>
        /// <returns>Defined by the overriding subclass.</returns>
        /// <exception cref="NotImplementedException">Always thrown by this base implementation.</exception>
        public virtual Tensor result()
            => throw new NotImplementedException(
                $"{GetType().Name} does not implement {nameof(result)}.");

        /// <summary>
        /// Returns a debug representation of the form <c>"name total/count"</c>.
        /// </summary>
        /// <remarks>
        /// <c>total</c> and <c>count</c> are only created by subclasses, so either may be absent.
        /// Missing accumulators are omitted rather than throwing, because <c>ToString</c> is
        /// invoked implicitly by debuggers, loggers and string interpolation.
        /// </remarks>
        public override string ToString()
        {
            if (total == null)
            {
                return $"{name}";
            }

            if (count == null)
            {
                return $"{name} {(float)total.numpy()}";
            }

            return $"{name} {(float)total.numpy()}/{(float)count.numpy()}";
        }
    }
}