@@ -11,14 +11,64 @@ namespace Tensorflow.Keras.Engine.DataAdapters
1111 /// </summary>
1212 public class TensorLikeDataAdapter : IDataAdapter
1313 {
14+ TensorLikeDataAdapterArgs args ;
15+ int _size ;
16+ int _batch_size ;
17+ int num_samples ;
18+ int num_full_batches ;
19+
1420 public TensorLikeDataAdapter ( TensorLikeDataAdapterArgs args )
1521 {
16- tf . data . Dataset . range ( 5 ) ;
22+ this . args = args ;
23+ _process_tensorlike ( ) ;
24+ num_samples = args . X . shape [ 0 ] ;
25+ var batch_size = args . BatchSize ;
26+ _batch_size = batch_size ;
27+ _size = Convert . ToInt32 ( Math . Ceiling ( num_samples / ( batch_size + 0f ) ) ) ;
28+ num_full_batches = num_samples / batch_size ;
29+ var _partial_batch_size = num_samples % batch_size ;
30+
31+ var indices_dataset = tf . data . Dataset . range ( 1 ) ;
32+ indices_dataset = indices_dataset . repeat ( ) ;
33+ indices_dataset = indices_dataset . map ( permutation ) . prefetch ( 1 ) ;
34+ indices_dataset = indices_dataset . flat_map ( slice_batch_indices ) ;
35+ }
36+
37+ Tensor permutation ( Tensor tensor )
38+ {
39+ var indices = math_ops . range ( num_samples , dtype : dtypes . int64 ) ;
40+ if ( args . Shuffle )
41+ indices = random_ops . random_shuffle ( indices ) ;
42+ return indices ;
43+ }
44+
45+ /// <summary>
46+ /// Convert a Tensor of indices into a dataset of batched indices.
47+ /// </summary>
48+ /// <param name="tensor"></param>
49+ /// <returns></returns>
50+ IDatasetV2 slice_batch_indices ( Tensor indices )
51+ {
52+ var num_in_full_batch = num_full_batches * _batch_size ;
53+ var first_k_indices = array_ops . slice ( indices , new int [ ] { 0 } , new int [ ] { num_in_full_batch } ) ;
54+ first_k_indices = array_ops . reshape ( first_k_indices , new int [ ] { num_full_batches , _batch_size } ) ;
55+ var flat_dataset = tf . data . Dataset . from_tensor_slices ( first_k_indices ) ;
56+
57+ return flat_dataset ;
58+ }
59+
60+ void slice_inputs ( IDatasetV2 indices_dataset , Tensor x , Tensor y )
61+ {
62+ var dataset = tf . data . Dataset . from_tensor ( x , y ) ;
1763 }
1864
1965 public bool CanHandle ( Tensor x , Tensor y = null )
2066 {
2167 throw new NotImplementedException ( ) ;
2268 }
69+
70+ void _process_tensorlike ( )
71+ {
72+ }
2373 }
2474}
0 commit comments