forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLoss.cs
More file actions
51 lines (43 loc) · 1.37 KB
/
Loss.cs
File metadata and controls
51 lines (43 loc) · 1.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
using Tensorflow.Keras.Utils;
namespace Tensorflow.Keras.Losses;
/// <summary>
/// Loss base class.
/// </summary>
/// <remarks>
/// Subclasses implement <see cref="Apply"/> to compute unreduced loss values;
/// <see cref="Call"/> then reduces them (optionally weighted by <c>sample_weight</c>)
/// via <c>losses_utils.compute_weighted_loss</c>.
/// </remarks>
public abstract class Loss : ILossFunc
{
    // Protected and non-readonly on purpose: derived classes may read or reassign these.
    protected string reduction;
    protected string name;
    protected bool from_logits;

    // NOTE(review): only ever assigned false and never read in this class;
    // presumably mirrors the Keras guard against SUM_OVER_BATCH_SIZE under
    // distribution strategies — confirm whether it is still needed.
    private readonly bool _allow_sum_over_batch_size;

    // NOTE(review): written only by _set_name_scope, which is not called in this class.
    private string _name_scope;

    /// <summary>
    /// The reduction type as stored at construction time. This may still be
    /// <c>ReductionV2.AUTO</c>; <see cref="Call"/> resolves AUTO when reducing.
    /// </summary>
    public string Reduction => reduction;

    /// <summary>The name passed to the constructor; may be <c>null</c>.</summary>
    public string Name => name;

    /// <summary>
    /// Initializes the shared loss configuration.
    /// </summary>
    /// <param name="reduction">
    /// Reduction type to apply in <see cref="Call"/>. A <c>null</c> value is
    /// replaced with <c>ReductionV2.SUM_OVER_BATCH_SIZE</c>.
    /// </param>
    /// <param name="name">Optional name for this loss instance.</param>
    /// <param name="from_logits">
    /// Value forwarded to <see cref="Apply"/> by <see cref="Call"/>.
    /// </param>
    public Loss(string reduction = ReductionV2.AUTO,
        string name = null,
        bool from_logits = false)
    {
        this.reduction = reduction ?? ReductionV2.SUM_OVER_BATCH_SIZE;
        this.name = name;
        this.from_logits = from_logits;
        _allow_sum_over_batch_size = false;
    }

    /// <summary>
    /// Computes the loss values before any reduction is applied.
    /// </summary>
    /// <param name="y_true">Ground-truth values.</param>
    /// <param name="y_pred">Predicted values.</param>
    /// <param name="from_logits">
    /// Whether <paramref name="y_pred"/> should be treated as logits; the exact
    /// interpretation is defined by each subclass.
    /// </param>
    /// <param name="axis">
    /// Axis argument (default <c>-1</c>); its meaning is defined by each subclass.
    /// </param>
    /// <returns>Unreduced losses; their shape is determined by the subclass.</returns>
    public abstract Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1);

    /// <summary>
    /// Computes the loss by calling <see cref="Apply"/> and reducing the result.
    /// </summary>
    /// <param name="y_true">Ground-truth values.</param>
    /// <param name="y_pred">Predicted values.</param>
    /// <param name="sample_weight">Optional weights passed to the weighted reduction.</param>
    /// <returns>The output of <c>losses_utils.compute_weighted_loss</c>.</returns>
    public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
    {
        // Forward the instance-level setting rather than Apply's own default of false.
        var losses = Apply(y_true, y_pred, from_logits: from_logits);

        // AUTO is resolved here, so the stored `reduction` field is left unchanged.
        var effectiveReduction = GetReduction();
        return losses_utils.compute_weighted_loss(losses, reduction: effectiveReduction, sample_weight: sample_weight);
    }

    /// <summary>
    /// Maps <c>ReductionV2.AUTO</c> to <c>ReductionV2.SUM_OVER_BATCH_SIZE</c>;
    /// every other value (including <c>null</c>) is returned unchanged.
    /// </summary>
    private string GetReduction() => reduction switch
    {
        ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE,
        _ => reduction
    };

    /// <summary>
    /// Copies <see cref="name"/> into the private name-scope field.
    /// </summary>
    /// <remarks>
    /// NOTE(review): not called anywhere in this class, and it stores <c>null</c>
    /// when no name was given. Confirm whether it should be wired up or removed.
    /// </remarks>
    private void _set_name_scope()
    {
        _name_scope = name;
    }
}