Skip to content

Commit 500f0c0

Browse files
committed
Fix preserve_cardinality for ParallelMapDataset.
1 parent e67b2a5 commit 500f0c0

10 files changed

Lines changed: 81 additions & 23 deletions

File tree

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,9 @@ public IDatasetV2 map(Func<Tensors, Tensors> map_func,
6666
use_legacy_function: use_legacy_function);
6767

6868
public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
69-
=> new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
69+
=> new ParallelMapDataset(this, map_func,
70+
num_parallel_calls: num_parallel_calls,
71+
preserve_cardinality: true);
7072

7173
public OwnedIterator make_one_shot_iterator()
7274
{

src/TensorFlowNET.Core/Data/ParallelMapDataset.cs

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,18 +15,26 @@ public ParallelMapDataset(IDatasetV2 input_dataset,
1515
bool preserve_cardinality = false,
1616
bool use_legacy_function = false) : base(input_dataset)
1717
{
18-
var func = new ConcreteFunction(map_func,
19-
input_dataset.element_spec.Select(x => x.dtype).ToArray(),
20-
input_dataset.element_spec.Select(x => x.shape).ToArray());
18+
var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}");
19+
func.Enter();
20+
var inputs = new Tensors();
21+
foreach (var input in input_dataset.element_spec)
22+
inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
23+
var outputs = map_func(inputs);
24+
func.ToGraph(inputs, outputs);
25+
func.Exit();
2126

2227
structure = func.OutputStructure;
28+
2329
var _num_parallel_calls = tf.convert_to_tensor(num_parallel_calls, dtype: tf.int64,
2430
name: "num_parallel_calls");
2531
variant_tensor = ops.parallel_map_dataset_v2(input_dataset.variant_tensor,
2632
_num_parallel_calls,
2733
func,
2834
output_types,
29-
output_shapes);
35+
output_shapes,
36+
use_inter_op_parallelism: use_inter_op_parallelism,
37+
preserve_cardinality: preserve_cardinality);
3038
}
3139
}
3240
}

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
7171
func_graph.Exit();
7272
}
7373

74-
public ConcreteFunction(Func<Tensors, Tensors> func,
74+
/*public ConcreteFunction(Func<Tensors, Tensors> func,
7575
TF_DataType[] dtypes, TensorShape[] shapes)
7676
{
7777
string func_name = $"{func.Method.Name}_{ops.uid_function()}";
@@ -89,7 +89,7 @@ public ConcreteFunction(Func<Tensors, Tensors> func,
8989
var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
9090
func_graph.ToGraph(opers, inputs, Outputs, null);
9191
func_graph.Exit();
92-
}
92+
}*/
9393

9494
public void ToGraph(Tensors inputs, Tensors outputs)
9595
{

src/TensorFlowNET.Core/Tensors/Tensors.cs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,8 @@ public Tensor this[int index]
3838
}
3939
}
4040

41+
public Tensor this[params string[] slices]
42+
=> items.First()[slices];
4143
public Tensors(params Tensor[] tensors)
4244
{
4345
items.AddRange(tensors);

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -585,6 +585,10 @@ public static string to_numpy_string(Tensor tensor)
585585
else
586586
return $"['{string.Join("', '", tensor.StringData().Take(25))}']";
587587
}
588+
else if(dtype == TF_DataType.TF_VARIANT)
589+
{
590+
return "<unprintable>";
591+
}
588592

589593
var nd = tensor.numpy();
590594

src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,7 @@ int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
100100
using var data_iterator = new OwnedIterator(_dataset);
101101
yield return (epoch, data_iterator);
102102
}
103+
// _adapter.on_epoch_end()
103104
}
104105

105106
public IEnumerable<int> steps()

src/TensorFlowNET.Keras/Engine/Model.Compile.cs

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -33,22 +33,22 @@ public void compile(OptimizerV2 optimizer = null,
3333

3434
public void compile(string optimizer, string loss, string[] metrics)
3535
{
36-
switch (optimizer)
36+
var _optimizer = optimizer switch
3737
{
38-
case "rmsprop":
39-
this.optimizer = new RMSprop(new RMSpropArgs
40-
{
38+
"rmsprop" => new RMSprop(new RMSpropArgs
39+
{
4140

42-
});
43-
break;
44-
}
41+
}),
42+
_ => throw new NotImplementedException("")
43+
};
4544

46-
int experimental_steps_per_execution = 1;
47-
_configure_steps_per_execution(experimental_steps_per_execution);
48-
49-
_reset_compile_cache();
45+
var _loss = loss switch
46+
{
47+
"mse" => new MeanSquaredError(),
48+
_ => throw new NotImplementedException("")
49+
};
5050

51-
_is_compiled = true;
51+
compile(optimizer: _optimizer, loss: _loss, metrics: metrics);
5252
}
5353
}
5454
}

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

Lines changed: 26 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,32 @@ public void evaluate(NDArray x, NDArray y,
4949
Binding.tf_output_redirect.WriteLine($"Testing...");
5050
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
5151
{
52-
// reset_metrics();
52+
reset_metrics();
53+
// callbacks.on_epoch_begin(epoch)
54+
// data_handler.catch_stop_iteration();
55+
IEnumerable<(string, Tensor)> results = null;
56+
foreach (var step in data_handler.steps())
57+
{
58+
// callbacks.on_train_batch_begin(step)
59+
results = test_function(iterator);
60+
}
61+
Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
62+
}
63+
}
64+
65+
public void evaluate(IDatasetV2 x)
66+
{
67+
data_handler = new DataHandler(new DataHandlerArgs
68+
{
69+
Dataset = x,
70+
Model = this,
71+
StepsPerExecution = _steps_per_execution
72+
});
73+
74+
Binding.tf_output_redirect.WriteLine($"Testing...");
75+
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
76+
{
77+
reset_metrics();
5378
// callbacks.on_epoch_begin(epoch)
5479
// data_handler.catch_stop_iteration();
5580
IEnumerable<(string, Tensor)> results = null;

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -124,10 +124,11 @@ public IDatasetV2 timeseries_dataset_from_array(Tensor data, int sequence_length
124124

125125
var start_positions_tensor = tf.constant(start_positions);
126126
var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat();
127-
var z = tf.data.Dataset.zip(tf.data.Dataset.range(len(start_positions)), positions_ds);
127+
var r = tf.data.Dataset.range(len(start_positions));
128+
var z = tf.data.Dataset.zip(r, positions_ds);
128129
var indices = z.map(m =>
129130
{
130-
var (i, positions) = (m[0], m[1]);
131+
var (i, positions) = m;
131132
return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor);
132133
}, num_parallel_calls: -1);
133134
var dataset = sequences_from_indices(data, indices, start_index, end_index);
@@ -142,7 +143,11 @@ IDatasetV2 sequences_from_indices(Tensor array, IDatasetV2 indices_ds, int start
142143
{
143144
var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]);
144145
dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds)
145-
.map(x => array_ops.gather(x[0], x[1]), num_parallel_calls: -1);
146+
.map(x =>
147+
{
148+
var (steps, indx) = x;
149+
return array_ops.gather(steps, indx);
150+
}, num_parallel_calls: -1);
146151
return dataset;
147152
}
148153
}

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,18 @@ public void Cache()
147147
public void Cardinality()
148148
{
149149
var dataset = tf.data.Dataset.range(10);
150+
var cardinality = dataset.dataset_cardinality();
151+
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
150152
dataset = dataset.map(x => x[0] + 1);
153+
cardinality = dataset.dataset_cardinality();
154+
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
155+
}
156+
157+
[TestMethod]
158+
public void CardinalityWithAutoTune()
159+
{
160+
var dataset = tf.data.Dataset.range(10);
161+
dataset = dataset.map(x => x, num_parallel_calls: -1);
151162
var cardinality = dataset.dataset_cardinality();
152163
Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
153164
}

0 commit comments

Comments
 (0)