Skip to content

Commit 2763f7c

Browse files
committed
fix tensor_slice_dataset.
1 parent 006eeaa commit 2763f7c

13 files changed

Lines changed: 205 additions & 20 deletions

File tree

src/TensorFlowNET.Core/Data/DatasetManager.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,18 @@ public IDatasetV2 from_generator<T>(IEnumerable<T> generator, TF_DataType[] outp
1919
/// <summary>
/// Wraps <paramref name="tensors"/> in a new <see cref="TensorDataset"/>.
/// </summary>
/// <param name="tensors">Array handed to the <see cref="TensorDataset"/> constructor.</param>
/// <returns>The newly created dataset.</returns>
public IDatasetV2 from_tensor(NDArray tensors)
{
    return new TensorDataset(tensors);
}
2121

22+
/// <summary>
/// Wraps a features/labels pair in a new <see cref="TensorDataset"/>.
/// </summary>
/// <param name="features">First tensor passed to the <see cref="TensorDataset"/> constructor.</param>
/// <param name="labels">Second tensor passed to the <see cref="TensorDataset"/> constructor.</param>
/// <returns>The newly created dataset.</returns>
public IDatasetV2 from_tensor(Tensor features, Tensor labels)
{
    return new TensorDataset(features, labels);
}
24+
2225
/// <summary>
/// Wraps a single tensor in a new <see cref="TensorDataset"/>.
/// </summary>
/// <param name="tensors">Tensor handed to the <see cref="TensorDataset"/> constructor.</param>
/// <returns>The newly created dataset.</returns>
public IDatasetV2 from_tensor(Tensor tensors)
{
    return new TensorDataset(tensors);
}
2427

2528
/// <summary>
/// Builds a <see cref="TensorSliceDataset"/> from a features/labels pair.
/// </summary>
/// <param name="features">First tensor passed to the <see cref="TensorSliceDataset"/> constructor.</param>
/// <param name="labels">Second tensor passed to the <see cref="TensorSliceDataset"/> constructor.</param>
/// <returns>The newly created dataset.</returns>
public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels)
{
    return new TensorSliceDataset(features, labels);
}
2730

31+
/// <summary>
/// Builds a <see cref="TensorSliceDataset"/> from a single tensor.
/// </summary>
/// <param name="tensor">Tensor handed to the <see cref="TensorSliceDataset"/> constructor.</param>
/// <returns>The newly created dataset.</returns>
public IDatasetV2 from_tensor_slices(Tensor tensor)
{
    return new TensorSliceDataset(tensor);
}
33+
2834
/// <summary>
/// Builds a <see cref="TensorSliceDataset"/> from an array of strings.
/// </summary>
/// <param name="array">Strings handed to the <see cref="TensorSliceDataset"/> constructor.</param>
/// <returns>The newly created dataset.</returns>
public IDatasetV2 from_tensor_slices(string[] array)
{
    return new TensorSliceDataset(array);
}
3036

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ public IDatasetV2 map(Func<Tensor, Tensor> map_func,
6060
preserve_cardinality: preserve_cardinality,
6161
use_legacy_function: use_legacy_function);
6262

63+
/// <summary>
/// Wraps this dataset and <paramref name="map_func"/> in a new
/// <see cref="FlatMapDataset"/>.
/// </summary>
/// <param name="map_func">Function that returns a dataset for a given tensor.</param>
/// <returns>The newly created <see cref="FlatMapDataset"/>.</returns>
public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
{
    return new FlatMapDataset(this, map_func);
}
65+
6366
/// <summary>
/// Wraps this dataset in a new <see cref="ModelDataset"/> created with the
/// given autotune algorithm and CPU budget.
/// </summary>
/// <param name="algorithm">Autotune algorithm forwarded to <see cref="ModelDataset"/>.</param>
/// <param name="cpu_budget">CPU budget forwarded to <see cref="ModelDataset"/>.</param>
/// <returns>The newly created <see cref="ModelDataset"/>.</returns>
public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
{
    return new ModelDataset(this, algorithm, cpu_budget);
}
6568

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Functions;
5+
6+
namespace Tensorflow
7+
{
8+
/// <summary>
/// Unary dataset whose variant tensor is produced by
/// <c>ops.flat_map_dataset</c> from an input dataset and a mapping function
/// that returns a dataset for each tensor it is given.
/// </summary>
public class FlatMapDataset : UnaryDataset
{
    /// <summary>
    /// Wraps <paramref name="map_func"/> in a <see cref="ConcreteFunction"/>
    /// typed by the dtype of the input's first element spec, then creates the
    /// flat-map op over the input's variant tensor.
    /// </summary>
    /// <param name="input_dataset">Dataset whose variant tensor feeds the op.</param>
    /// <param name="map_func">Function from a tensor to a dataset.</param>
    public FlatMapDataset(IDatasetV2 input_dataset,
        Func<Tensor, IDatasetV2> map_func) : base(input_dataset)
    {
        // Only the first element spec's dtype is used for the traced function's input.
        var element_dtype = input_dataset.element_spec[0].dtype;
        var mapped_function = new ConcreteFunction(map_func, element_dtype);

        var source_variant = input_dataset.variant_tensor;
        variant_tensor = ops.flat_map_dataset(
            source_variant,
            mapped_function,
            output_types,
            output_shapes);
    }
}
24+
}

src/TensorFlowNET.Core/Data/IDatasetV2.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ IDatasetV2 map(Func<Tensor, Tensor> map_func,
6262
bool preserve_cardinality = false,
6363
bool use_legacy_function = false);
6464

65+
/// <summary>
/// Produces a dataset from this one using <paramref name="map_func"/>, a
/// function that returns an <see cref="IDatasetV2"/> for a given
/// <see cref="Tensor"/>.
/// </summary>
/// <param name="map_func">Function from a tensor to a dataset.</param>
/// <returns>The resulting dataset.</returns>
IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);
66+
6567
/// <summary>
/// Produces a dataset from this one for the given autotune algorithm and
/// CPU budget.
/// </summary>
/// <param name="algorithm">Autotune algorithm to use.</param>
/// <param name="cpu_budget">CPU budget value. NOTE(review): units are not visible from this declaration; confirm against the implementation.</param>
/// <returns>The resulting dataset.</returns>
IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);
6668

6769
/// <summary>

src/TensorFlowNET.Core/Data/TensorDataset.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@ namespace Tensorflow
1212
/// </summary>
1313
public class TensorDataset : DatasetSource
1414
{
15+
/// <summary>
/// Creates the dataset from a feature tensor and a label tensor. The
/// element structure is the unbatched tensor spec of each tensor, and the
/// variant tensor comes from <c>ops.tensor_dataset</c>.
/// </summary>
/// <param name="feature">First component tensor.</param>
/// <param name="label">Second component tensor.</param>
public TensorDataset(Tensor feature, Tensor label)
{
    _tensors = new[] { feature, label };

    // Specs are collected for every tensor first, then each one is unbatched.
    var full_specs = (from component in _tensors
                      select component.ToTensorSpec()).ToArray();
    structure = (from spec in full_specs
                 select spec._unbatch()).ToArray();

    variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
}
1524
public TensorDataset(Tensor element)
1625
{
1726
_tensors = new[] { element };

src/TensorFlowNET.Core/Data/TensorSliceDataset.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ public TensorSliceDataset(NDArray array)
3131
variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes);
3232
}
3333

34+
/// <summary>
/// Creates the dataset from a single tensor. The element structure is the
/// unbatched tensor spec of <paramref name="tensor"/>, and the variant
/// tensor comes from <c>ops.tensor_slice_dataset</c>.
/// </summary>
/// <param name="tensor">The only component tensor of the dataset.</param>
public TensorSliceDataset(Tensor tensor)
{
    _tensors = new[] { tensor };

    // Single component: unbatch its spec directly instead of going through LINQ.
    structure = new[] { tensor.ToTensorSpec()._unbatch() };

    variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes);
}
42+
3443
public TensorSliceDataset(Tensor features, Tensor labels)
3544
{
3645
_tensors = new[] { features, labels };

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,24 @@ public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
3333
}
3434
}
3535

36+
/// <summary>
/// Traces <paramref name="func"/> inside a new <see cref="FuncGraph"/> and
/// stores the handle of the resulting graph function. The function takes a
/// single placeholder of type <paramref name="dtype"/> as input and returns
/// the variant tensor of the dataset produced by <paramref name="func"/>.
/// </summary>
/// <param name="func">Function that builds a dataset from the input tensor.</param>
/// <param name="dtype">Data type of the placeholder fed to <paramref name="func"/>.</param>
public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
{
    string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";

    using (var graph = new FuncGraph(func_name))
    {
        var input = tf.placeholder(dtype);
        var output = func(input);

        var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();

        // The dataset's variant tensor must be declared as the function's
        // output; with an empty output list the traced function returns
        // nothing and cannot be used as the mapper of a flat-map dataset op.
        _handle = graph.ToGraph(opers,
            new Operation[] { input },
            new Operation[] { output.variant_tensor },
            null);
    }
}
53+
3654
public Tensor Execute(Tensor arg)
3755
{
3856
var result = tf.Runner.TFE_Execute(tf.Context,
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Keras.Engine;
45

56
namespace Tensorflow.Keras.ArgsDefinition
67
{
78
/// <summary>
/// Plain settings object passed to the tensor-like data adapter.
/// </summary>
public class TensorLikeDataAdapterArgs
{
    // Input tensors.
    public Tensor X { get; set; }
    public Tensor Y { get; set; }

    // Batching and iteration settings. BatchSize starts at 32 unless a
    // caller assigns another value.
    public int BatchSize { get; set; } = 32;
    public int Steps { get; set; }
    public int Epochs { get; set; }
    public bool Shuffle { get; set; }

    // Queueing / worker settings carried alongside the data.
    public int MaxQueueSize { get; set; }
    public int Worker { get; set; }
    public bool UseMultiprocessing { get; set; }

    // Model associated with this adapter configuration.
    public Model Model { get; set; }
}
1621
}

src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,30 @@ namespace Tensorflow.Keras.Engine.DataAdapters
1111
public class DataHandler
{
    DataHandlerArgs args;
    IDataAdapter _adapter;

    /// <summary>
    /// Stores <paramref name="args"/> and creates a
    /// <see cref="TensorLikeDataAdapter"/> configured from them.
    /// </summary>
    /// <param name="args">Handler configuration copied into the adapter arguments.</param>
    public DataHandler(DataHandlerArgs args)
    {
        this.args = args;

        var adapter_args = new TensorLikeDataAdapterArgs();
        adapter_args.X = args.X;
        adapter_args.Y = args.Y;
        adapter_args.BatchSize = args.BatchSize;
        adapter_args.Steps = args.StepsPerEpoch;
        // The adapter receives the number of epochs remaining after InitialEpoch.
        adapter_args.Epochs = args.Epochs - args.InitialEpoch;
        adapter_args.Shuffle = args.Shuffle;
        adapter_args.MaxQueueSize = args.MaxQueueSize;
        adapter_args.Worker = args.Workers;
        adapter_args.UseMultiprocessing = args.UseMultiprocessing;
        adapter_args.Model = args.Model;

        _adapter = new TensorLikeDataAdapter(adapter_args);
    }

    /// <summary>
    /// Not implemented yet; always throws.
    /// </summary>
    /// <param name="dataset">Currently unused.</param>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    Tensor _infer_steps(IDatasetV2 dataset)
    {
        throw new NotImplementedException("");
    }
}
3540
}

src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,64 @@ namespace Tensorflow.Keras.Engine.DataAdapters
1111
/// </summary>
1212
public class TensorLikeDataAdapter : IDataAdapter
{
    TensorLikeDataAdapterArgs args;
    // Number of batches needed to cover num_samples, counting a trailing partial batch.
    int _size;
    int _batch_size;
    // Length of the first dimension of args.X.
    int num_samples;
    // Number of batches that hold exactly _batch_size samples.
    int num_full_batches;

    /// <summary>
    /// Stores the arguments, derives the batch bookkeeping from the first
    /// dimension of <c>args.X</c> and <c>args.BatchSize</c>, and builds a
    /// dataset of batched sample indices.
    /// </summary>
    /// <param name="args">Adapter configuration; X, BatchSize and Shuffle are read by this class.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <c>args.BatchSize</c> is zero or negative.
    /// </exception>
    public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args)
    {
        this.args = args;
        _process_tensorlike();
        num_samples = args.X.shape[0];

        var batch_size = args.BatchSize;
        // Fail with a clear message instead of an incidental overflow or
        // divide-by-zero error in the arithmetic below.
        if (batch_size <= 0)
            throw new ArgumentOutOfRangeException(nameof(args), batch_size, "BatchSize must be positive.");

        _batch_size = batch_size;
        num_full_batches = num_samples / batch_size;
        // Exact integer ceiling division. Going through single-precision
        // floats loses precision once num_samples exceeds 2^24.
        _size = num_full_batches + (num_samples % batch_size > 0 ? 1 : 0);

        // NOTE(review): the pipeline below is built but its result is not yet
        // stored or returned anywhere in this class.
        var indices_dataset = tf.data.Dataset.range(1);
        indices_dataset = indices_dataset.repeat();
        indices_dataset = indices_dataset.map(permutation).prefetch(1);
        indices_dataset = indices_dataset.flat_map(slice_batch_indices);
    }

    /// <summary>
    /// Builds the int64 range tensor for <c>num_samples</c> and shuffles it
    /// when <c>args.Shuffle</c> is set.
    /// </summary>
    /// <param name="tensor">Ignored; present so the method can be used as a map function.</param>
    /// <returns>The (optionally shuffled) index tensor.</returns>
    Tensor permutation(Tensor tensor)
    {
        var indices = math_ops.range(num_samples, dtype: dtypes.int64);
        if (args.Shuffle)
            indices = random_ops.random_shuffle(indices);
        return indices;
    }

    /// <summary>
    /// Convert a Tensor of indices into a dataset of batched indices.
    /// Only the first <c>num_full_batches * _batch_size</c> indices are used;
    /// they are reshaped to <c>[num_full_batches, _batch_size]</c> and sliced
    /// into a dataset. Indices belonging to a trailing partial batch are not
    /// included.
    /// </summary>
    /// <param name="indices">One-dimensional tensor of sample indices.</param>
    /// <returns>Dataset created by slicing the reshaped index tensor.</returns>
    IDatasetV2 slice_batch_indices(Tensor indices)
    {
        var num_in_full_batch = num_full_batches * _batch_size;
        var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch });
        first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size });
        var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices);

        return flat_dataset;
    }

    /// <summary>
    /// Creates a tensor dataset from <paramref name="x"/> and <paramref name="y"/>.
    /// NOTE(review): the dataset is not combined with
    /// <paramref name="indices_dataset"/> or returned yet.
    /// </summary>
    /// <param name="indices_dataset">Currently unused.</param>
    /// <param name="x">Feature tensor.</param>
    /// <param name="y">Label tensor.</param>
    void slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y)
    {
        var dataset = tf.data.Dataset.from_tensor(x, y);
    }

    /// <summary>
    /// Not implemented yet; always throws.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public bool CanHandle(Tensor x, Tensor y = null)
    {
        throw new NotImplementedException();
    }

    /// <summary>
    /// Placeholder hook called at the start of construction; currently does nothing.
    /// </summary>
    void _process_tensorlike()
    {
    }
}
2474
}

0 commit comments

Comments
 (0)