Skip to content

Commit 3a220fa

Browse files
committed
tf.data framework.
1 parent 388b645 commit 3a220fa

32 files changed

Lines changed: 993 additions & 70 deletions

src/TensorFlowNET.Core/APIs/c_api.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ namespace Tensorflow
4343
/// </summary>
4444
public partial class c_api
4545
{
46-
public const string TensorFlowLibName = "tensorflow";
46+
public const string TensorFlowLibName = @"D:\SciSharp\tensorflow-google\bazel-bin\tensorflow\tensorflow.dll";
4747

4848
public static string StringPiece(IntPtr handle)
4949
{

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,17 @@ public static float time()
265265
yield return (i, values[i]);
266266
}
267267

268+
public static IEnumerable<(int, T)> enumerate<T>(IEnumerable<T> values, int start = 0)
269+
{
270+
int i = 0;
271+
foreach(var val in values)
272+
{
273+
if (i < start)
274+
continue;
275+
yield return (i, val);
276+
}
277+
}
278+
268279
[DebuggerStepThrough]
269280
public static Dictionary<string, object> ConvertToDict(object dyn)
270281
{
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Framework.Models;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow
9+
{
10+
/// <summary>
11+
/// A `Dataset` that batches contiguous elements from its input.
12+
/// </summary>
13+
public class BatchDataset : UnaryDataset
14+
{
15+
Tensor _batch_size;
16+
Tensor _drop_remainder;
17+
18+
public BatchDataset(IDatasetV2 input_dataset, int batch_size, bool drop_remainder = false) :
19+
base(input_dataset)
20+
{
21+
_input_dataset = input_dataset;
22+
_batch_size = tf.convert_to_tensor(batch_size, dtype: TF_DataType.TF_INT64, name: "batch_size");
23+
_drop_remainder = tf.convert_to_tensor(drop_remainder, dtype: TF_DataType.TF_BOOL, name: "drop_remainder");
24+
25+
if (drop_remainder)
26+
{
27+
throw new NotImplementedException("");
28+
}
29+
else
30+
{
31+
_structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray();
32+
}
33+
34+
variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor,
35+
_batch_size,
36+
_drop_remainder,
37+
output_types,
38+
output_shapes);
39+
}
40+
}
41+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Framework.Models;
6+
7+
namespace Tensorflow
8+
{
9+
public class DatasetSource : DatasetV2
10+
{
11+
protected Tensor[] _tensors;
12+
13+
public DatasetSource()
14+
{
15+
16+
}
17+
}
18+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using System;
2+
using System.Collections;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using Tensorflow.Framework.Models;
7+
8+
namespace Tensorflow
9+
{
10+
/// <summary>
11+
/// Abstract class representing a dataset with no inputs.
12+
/// </summary>
13+
public class DatasetV2 : IDatasetV2
14+
{
15+
protected dataset_ops ops = new dataset_ops();
16+
public Tensor variant_tensor { get; set; }
17+
18+
public TensorSpec[] _structure { get; set; }
19+
20+
public TensorShape[] output_shapes => _structure.Select(x => x.shape).ToArray();
21+
22+
public TF_DataType[] output_types => _structure.Select(x => x.dtype).ToArray();
23+
24+
public TensorSpec[] element_spec => _structure;
25+
26+
public IDatasetV2 take(int count = -1)
27+
=> new TakeDataset(this, count: count);
28+
29+
public IDatasetV2 batch(int batch_size, bool drop_remainder = false)
30+
=> new BatchDataset(this, batch_size, drop_remainder: drop_remainder);
31+
32+
public IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null)
33+
=> new PrefetchDataset(this, buffer_size: buffer_size, slack_period: slack_period);
34+
35+
public IDatasetV2 repeat(int count = -1)
36+
=> new RepeatDataset(this, count: count);
37+
38+
public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
39+
=> new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);
40+
41+
public override string ToString()
42+
=> $"{GetType().Name} shapes: ({_structure[0].shape}, {_structure[1].shape}), types: (tf.{_structure[0].dtype.as_numpy_name()}, tf.{_structure[1].dtype.as_numpy_name()})";
43+
44+
public IEnumerator<(Tensor, Tensor)> GetEnumerator()
45+
{
46+
throw new NotImplementedException();
47+
}
48+
49+
IEnumerator IEnumerable.GetEnumerator()
50+
{
51+
return this.GetEnumerator();
52+
}
53+
}
54+
}
Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,35 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Framework.Models;
45

56
namespace Tensorflow
67
{
7-
public interface IDatasetV2
8+
public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
89
{
10+
Tensor variant_tensor { get; set; }
911

12+
TensorShape[] output_shapes { get; }
13+
14+
TF_DataType[] output_types { get; }
15+
16+
TensorSpec[] element_spec { get; }
17+
18+
TensorSpec[] _structure { get; set; }
19+
20+
/// <summary>
21+
///
22+
/// </summary>
23+
/// <param name="count"></param>
24+
/// <returns></returns>
25+
IDatasetV2 repeat(int count = -1);
26+
27+
IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true);
28+
29+
IDatasetV2 batch(int batch_size, bool drop_remainder = false);
30+
31+
IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null);
32+
33+
IDatasetV2 take(int count);
1034
}
1135
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow
7+
{
8+
/// <summary>
9+
/// Creates a `Dataset` that prefetches elements from this dataset.
10+
/// </summary>
11+
public class PrefetchDataset : UnaryUnchangedStructureDataset
12+
{
13+
Tensor _buffer_size;
14+
15+
public PrefetchDataset(IDatasetV2 input_dataset,
16+
long buffer_size = -1,
17+
int? slack_period = null) :
18+
base(input_dataset)
19+
{
20+
_buffer_size = tf.convert_to_tensor(buffer_size, dtype: TF_DataType.TF_INT64, name: "buffer_size");
21+
22+
variant_tensor = ops.prefetch_dataset(input_dataset.variant_tensor,
23+
_buffer_size,
24+
input_dataset.output_types,
25+
input_dataset.output_shapes,
26+
slack_period: slack_period);
27+
}
28+
}
29+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
/// <summary>
8+
/// A `Dataset` that repeats its input several times.
9+
/// </summary>
10+
public class RepeatDataset : UnaryUnchangedStructureDataset
11+
{
12+
Tensor _count;
13+
14+
public RepeatDataset(IDatasetV2 input_dataset, int count = -1) :
15+
base(input_dataset)
16+
{
17+
_count = constant_op.constant(count, dtype: TF_DataType.TF_INT64, name: "count");
18+
variant_tensor = ops.repeat_dataset(input_dataset.variant_tensor,
19+
_count,
20+
input_dataset.output_types,
21+
input_dataset.output_shapes);
22+
}
23+
}
24+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow
7+
{
8+
/// <summary>
9+
/// Randomly shuffles the elements of this dataset.
10+
/// </summary>
11+
public class ShuffleDataset : UnaryUnchangedStructureDataset
12+
{
13+
Tensor _buffer_size;
14+
Tensor _seed;
15+
Tensor _seed2;
16+
bool _reshuffle_each_iteration;
17+
18+
public ShuffleDataset(IDatasetV2 input_dataset,
19+
long buffer_size,
20+
int? seed = null,
21+
bool reshuffle_each_iteration = true) :
22+
base(input_dataset)
23+
{
24+
_buffer_size = tf.convert_to_tensor(buffer_size, dtype: TF_DataType.TF_INT64, name: "buffer_size");
25+
(_seed, _seed2) = random_seed.get_seed_tensor(seed);
26+
_reshuffle_each_iteration = reshuffle_each_iteration;
27+
var seed_generator = ops.dummy_seed_generator();
28+
if (tf.context.executing_eagerly())
29+
variant_tensor = ops.shuffle_dataset_v3(input_dataset.variant_tensor, _buffer_size,
30+
_seed, _seed2, seed_generator,
31+
output_types, output_shapes,
32+
reshuffle_each_iteration: _reshuffle_each_iteration);
33+
else
34+
throw new NotImplementedException("");
35+
}
36+
}
37+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow
7+
{
8+
public class TakeDataset : UnaryUnchangedStructureDataset
9+
{
10+
Tensor _count;
11+
12+
public TakeDataset(IDatasetV2 input_dataset, int count) :
13+
base(input_dataset)
14+
{
15+
_count = tf.convert_to_tensor(count, dtype: dtypes.int64, name: "count");
16+
variant_tensor = ops.take_dataset(input_dataset.variant_tensor, _count,
17+
output_types, output_shapes);
18+
}
19+
}
20+
}

0 commit comments

Comments
 (0)