Skip to content

Commit ca9f574

Browse files
committed
Add metric of top_k_categorical_accuracy.
1 parent b8645d3 commit ca9f574

9 files changed

Lines changed: 117 additions & 12 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ public Tensor erf(Tensor x, string name = null)
3939
public Tensor sum(Tensor x, Axis? axis = null, string name = null)
4040
=> math_ops.reduce_sum(x, axis: axis, name: name);
4141

42+
public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK")
43+
=> nn_ops.in_top_k(predictions, targets, k, name);
44+
4245
/// <summary>
4346
///
4447
/// </summary>

src/TensorFlowNET.Core/Keras/IKerasApi.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
using System.Text;
44
using Tensorflow.Keras.Layers;
55
using Tensorflow.Keras.Losses;
6+
using Tensorflow.Keras.Metrics;
67

78
namespace Tensorflow.Keras
89
{
910
public interface IKerasApi
1011
{
1112
public ILayersApi layers { get; }
1213
public ILossesApi losses { get; }
14+
public IMetricsApi metrics { get; }
1315
public IInitializersApi initializers { get; }
1416
}
1517
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
namespace Tensorflow.Keras.Metrics;
2+
3+
public interface IMetricsApi
4+
{
5+
Tensor binary_accuracy(Tensor y_true, Tensor y_pred);
6+
7+
Tensor categorical_accuracy(Tensor y_true, Tensor y_pred);
8+
9+
Tensor mean_absolute_error(Tensor y_true, Tensor y_pred);
10+
11+
Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred);
12+
13+
/// <summary>
14+
/// Calculates how often predictions matches integer labels.
15+
/// </summary>
16+
/// <param name="y_true">Integer ground truth values.</param>
17+
/// <param name="y_pred">The prediction values.</param>
18+
/// <returns>Sparse categorical accuracy values.</returns>
19+
Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred);
20+
21+
/// <summary>
22+
/// Computes how often targets are in the top `K` predictions.
23+
/// </summary>
24+
/// <param name="y_true"></param>
25+
/// <param name="y_pred"></param>
26+
/// <param name="k"></param>
27+
/// <returns></returns>
28+
Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5);
29+
}

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,8 @@ public static Tensor log_softmax(Tensor logits, string name = null)
240240
/// <param name="name"></param>
241241
/// <returns>A `Tensor` of type `bool`.</returns>
242242
public static Tensor in_top_kv2(Tensor predictions, Tensor targets, int k, string name = null)
243-
{
244-
var _op = tf.OpDefLib._apply_op_helper("InTopKV2", name: name, args: new
245-
{
246-
predictions,
247-
targets,
248-
k
249-
});
250-
251-
return _op.output;
252-
}
243+
=> tf.Context.ExecuteOp("InTopKV2", name,
244+
new ExecuteOpArgs(predictions, targets, k));
253245

254246
public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null)
255247
=> tf.Context.ExecuteOp("LeakyRelu", name,

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
121121
if (dtype == TF_DataType.TF_INT32)
122122
values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray();
123123
}
124+
else if (values is double[] double_values)
125+
{
126+
if (dtype == TF_DataType.TF_FLOAT)
127+
values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray();
128+
}
124129
else
125130
values = Convert.ChangeType(values, new_system_dtype);
126131

src/TensorFlowNET.Keras/KerasInterface.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public class KerasInterface : IKerasApi
2727
ThreadLocal<BackendImpl> _backend = new ThreadLocal<BackendImpl>(() => new BackendImpl());
2828
public BackendImpl backend => _backend.Value;
2929
public OptimizerApi optimizers { get; } = new OptimizerApi();
30-
public MetricsApi metrics { get; } = new MetricsApi();
30+
public IMetricsApi metrics { get; } = new MetricsApi();
3131
public ModelsApi models { get; } = new ModelsApi();
3232
public KerasUtils utils { get; } = new KerasUtils();
3333

src/TensorFlowNET.Keras/Metrics/MetricsApi.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
namespace Tensorflow.Keras.Metrics
44
{
5-
public class MetricsApi
5+
public class MetricsApi : IMetricsApi
66
{
77
public Tensor binary_accuracy(Tensor y_true, Tensor y_pred)
88
{
@@ -53,5 +53,12 @@ public Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred)
5353
var diff = (y_true - y_pred) / math_ops.maximum(math_ops.abs(y_true), keras.backend.epsilon());
5454
return 100f * keras.backend.mean(math_ops.abs(diff), axis: -1);
5555
}
56+
57+
public Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5)
58+
{
59+
return metrics_utils.sparse_top_k_categorical_matches(
60+
tf.math.argmax(y_true, axis: -1), y_pred, k
61+
);
62+
}
5663
}
5764
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using Tensorflow.NumPy;
2+
3+
namespace Tensorflow.Keras.Metrics;
4+
5+
public class metrics_utils
6+
{
7+
public static Tensor sparse_top_k_categorical_matches(Tensor y_true, Tensor y_pred, int k = 5)
8+
{
9+
var reshape_matches = false;
10+
var y_true_rank = y_true.shape.ndim;
11+
var y_pred_rank = y_pred.shape.ndim;
12+
var y_true_org_shape = tf.shape(y_true);
13+
14+
if (y_pred_rank > 2)
15+
{
16+
y_pred = tf.reshape(y_pred, (-1, y_pred.shape[-1]));
17+
}
18+
19+
if (y_true_rank > 1)
20+
{
21+
reshape_matches = true;
22+
y_true = tf.reshape(y_true, new Shape(-1));
23+
}
24+
25+
var matches = tf.cast(
26+
tf.math.in_top_k(
27+
predictions: y_pred, targets: tf.cast(y_true, np.int32), k: k
28+
),
29+
dtype: keras.backend.floatx()
30+
);
31+
32+
if (reshape_matches)
33+
{
34+
return tf.reshape(matches, shape: y_true_org_shape);
35+
}
36+
37+
return matches;
38+
}
39+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using System.Threading.Tasks;
7+
using Tensorflow;
8+
using Tensorflow.NumPy;
9+
using static Tensorflow.Binding;
10+
using static Tensorflow.KerasApi;
11+
12+
namespace TensorFlowNET.Keras.UnitTest;
13+
14+
[TestClass]
15+
public class MetricsTest : EagerModeTestBase
16+
{
17+
/// <summary>
18+
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/top_k_categorical_accuracy
19+
/// </summary>
20+
[TestMethod]
21+
public void top_k_categorical_accuracy()
22+
{
23+
var y_true = np.array(new[,] { { 0, 0, 1 }, { 0, 1, 0 } });
24+
var y_pred = np.array(new[,] { { 0.1f, 0.9f, 0.8f }, { 0.05f, 0.95f, 0f } });
25+
var m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k: 3);
26+
Assert.AreEqual(m.numpy(), new[] { 1f, 1f });
27+
}
28+
}

0 commit comments

Comments
 (0)