forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: LogCosh.cs
More file actions
20 lines (17 loc) · 741 Bytes
/
LogCosh.cs
File metadata and controls
20 lines (17 loc) · 741 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
namespace Tensorflow.Keras.Losses;
/// <summary>
/// Keras-style "log_cosh" loss: the logarithm of the hyperbolic cosine of the
/// prediction error (<c>y_pred - y_true</c>), averaged over <c>axis</c>.
/// </summary>
/// <remarks>
/// log(cosh(x)) is computed through the identity
/// <c>log(cosh(x)) = x + softplus(-2x) - log(2)</c>. This is algebraically equal
/// but avoids evaluating cosh(x) directly, which overflows for large |x|.
/// </remarks>
public class LogCosh : LossFunctionWrapper
{
    /// <summary>Creates the loss.</summary>
    /// <param name="reduction">Reduction mode forwarded to <see cref="LossFunctionWrapper"/>.</param>
    /// <param name="name">Loss name; defaults to <c>"log_cosh"</c> when null.</param>
    public LogCosh(
        string reduction = null,
        string name = null) :
        base(reduction: reduction, name: name ?? "log_cosh")
    { }

    /// <summary>
    /// Computes the mean log-cosh error between <paramref name="y_true"/> and
    /// <paramref name="y_pred"/> along <paramref name="axis"/>.
    /// </summary>
    /// <param name="y_true">Ground-truth values; cast to the dtype of <paramref name="y_pred"/>.</param>
    /// <param name="y_pred">Predicted values.</param>
    /// <param name="from_logits">Not used by this loss; present to satisfy the override signature.</param>
    /// <param name="axis">Axis to average over (default <c>-1</c>, the last axis).</param>
    /// <returns>The per-sample loss, with <paramref name="axis"/> reduced away.</returns>
    /// <exception cref="System.ArgumentNullException">If either input tensor is null.</exception>
    public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
    {
        // Both tensors are required despite the null defaults inherited from the signature;
        // fail with a clear message instead of deep inside the op construction.
        if (y_true is null) throw new System.ArgumentNullException(nameof(y_true));
        if (y_pred is null) throw new System.ArgumentNullException(nameof(y_pred));

        Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
        Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
        Tensor x = y_pred_dispatch - y_true_cast;

        // log(2) is a fixed constant. Build it from tf.constant rather than allocating
        // a new tf.Variable on every call, then match x's dtype for the subtraction.
        Tensor log2 = math_ops.cast(math_ops.log(tf.constant(2.0)), x.dtype);

        // x + softplus(-2x) - log(2) == log(cosh(x)), evaluated without overflowing cosh.
        Tensor log_cosh = x + gen_nn_ops.softplus(-2.0 * x) - log2;
        return gen_math_ops.mean(log_cosh, ops.convert_to_tensor(axis));
    }
}