Skip to content

Commit cb7e017

Browse files
committed
mean_absolute_percentage_error
1 parent 492d6e1 commit cb7e017

4 files changed

Lines changed: 15 additions & 3 deletions

File tree

src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
using System.Collections.Generic;
2020
using System.Text;
2121
using static Tensorflow.Binding;
22+
using System.Linq;
2223

2324
namespace Tensorflow
2425
{
@@ -62,5 +63,8 @@ public Tensor this[params Slice[] slices]
6263
});
6364
}
6465
}
66+
67+
public Tensor this[params string[] slices]
68+
=> this[slices.Select(x => new Slice(x)).ToArray()];
6569
}
6670
}

src/TensorFlowNET.Keras/Engine/MetricsContainer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
7777
metric_obj = keras.metrics.categorical_accuracy;
7878
}
7979
else if(metric == "mean_absolute_error" || metric == "mae")
80-
{
8180
metric_obj = keras.metrics.mean_absolute_error;
82-
}
81+
else if (metric == "mean_absolute_percentage_error" || metric == "mape")
82+
metric_obj = keras.metrics.mean_absolute_percentage_error;
8383
else
8484
throw new NotImplementedException("");
8585

src/TensorFlowNET.Keras/Engine/Model.Compile.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,10 @@ public void compile(string optimizer, string loss, string[] metrics)
4242
_ => throw new NotImplementedException("")
4343
};
4444

45-
var _loss = loss switch
45+
ILossFunc _loss = loss switch
4646
{
4747
"mse" => new MeanSquaredError(),
48+
"mae" => new MeanAbsoluteError(),
4849
_ => throw new NotImplementedException("")
4950
};
5051

src/TensorFlowNET.Keras/Metrics/MetricsApi.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,12 @@ public Tensor mean_absolute_error(Tensor y_true, Tensor y_pred)
4646
y_true = math_ops.cast(y_true, y_pred.dtype);
4747
return keras.backend.mean(math_ops.abs(y_pred - y_true), axis: -1);
4848
}
49+
50+
public Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred)
51+
{
52+
y_true = math_ops.cast(y_true, y_pred.dtype);
53+
var diff = (y_true - y_pred) / math_ops.maximum(math_ops.abs(y_true), keras.backend.epsilon());
54+
return 100f * keras.backend.mean(math_ops.abs(diff), axis: -1);
55+
}
4956
}
5057
}

0 commit comments

Comments
 (0)