Skip to content

Commit 0926654

Browse files
committed
add EstimatorV2 and IEstimator.
1 parent bbed4fb commit 0926654

11 files changed

Lines changed: 180 additions & 3 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.io.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace Tensorflow
66
{
7-
public partial class tf
7+
public static partial class tf
88
{
99
public static Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name);
1010

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Data
6+
{
7+
public class DatasetV1 : DatasetV2
8+
{
9+
}
10+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Data
6+
{
7+
/// <summary>
8+
/// Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API.
9+
/// </summary>
10+
public class DatasetV1Adapter : DatasetV1
11+
{
12+
}
13+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Data
6+
{
7+
/// <summary>
8+
/// Represents a potentially large set of elements.
9+
///
10+
/// A `Dataset` can be used to represent an input pipeline as a
11+
/// collection of elements (nested structures of tensors) and a "logical
12+
/// plan" of transformations that act on those elements.
13+
///
14+
/// tensorflow\python\data\ops\dataset_ops.py
15+
/// </summary>
16+
public class DatasetV2
17+
{
18+
}
19+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Data;
5+
6+
namespace Tensorflow.Estimator
7+
{
8+
/// <summary>
9+
/// Estimator class to train and evaluate TensorFlow models.
10+
/// <see cref="tensorflow_estimator\python\estimator\estimator.py"/>
11+
/// </summary>
12+
public class EstimatorV2 : IEstimator
13+
{
14+
public EstimatorV2(string model_dir = null)
15+
{
16+
17+
}
18+
19+
/// <summary>
20+
/// Calls the input function.
21+
/// </summary>
22+
/// <param name="mode"></param>
23+
public void call_input_fn(string mode = null)
24+
{
25+
26+
}
27+
28+
public void train_model_default(Func<string, string, HyperParams, bool, DatasetV1Adapter> input_fn)
29+
{
30+
31+
}
32+
33+
public void get_features_and_labels_from_input_fn()
34+
{
35+
36+
}
37+
}
38+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Estimator
6+
{
7+
public class HyperParams
8+
{
9+
public string data_dir { get; set; }
10+
public string result_dir { get; set; }
11+
public string model_dir { get; set; }
12+
public string eval_dir { get; set; }
13+
14+
public int dim { get; set; } = 300;
15+
public float dropout { get; set; } = 0.5f;
16+
public int num_oov_buckets { get; set; } = 1;
17+
public int epochs { get; set; } = 25;
18+
public int batch_size { get; set; } = 20;
19+
public int buffer { get; set; } = 15000;
20+
public int lstm_size { get; set; } = 100;
21+
22+
public string words { get; set; }
23+
public string chars { get; set; }
24+
public string tags { get; set; }
25+
public string glove { get; set; }
26+
}
27+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Estimator
6+
{
7+
public interface IEstimator
8+
{
9+
10+
}
11+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
### TensorFlow Estimator
2+
3+
TensorFlow Estimator is a high-level TensorFlow API that greatly simplifies machine learning programming. Estimators encapsulate training, evaluation, prediction, and exporting for your model.
4+
5+
Guide: <https://www.tensorflow.org/guide/estimators>
6+
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Estimator
6+
{
7+
/// <summary>
8+
/// The executor to run `Estimator` training and evaluation.
9+
/// <see cref="tensorflow_estimator\python\estimator\training.py"/>
10+
/// </summary>
11+
public class TrainingExecutor : Python
12+
{
13+
private IEstimator _estimator;
14+
public TrainingExecutor(IEstimator estimator)
15+
{
16+
_estimator = estimator;
17+
}
18+
}
19+
}

test/TensorFlowNET.Examples/Text/NER/BiLstmCrfNer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples
1313
/// </summary>
1414
public class BiLstmCrfNer : IExample
1515
{
16-
public int Priority => 13;
16+
public int Priority => 101;
1717

1818
public bool Enabled { get; set; } = true;
1919
public bool ImportGraph { get; set; } = false;
@@ -24,7 +24,7 @@ public class BiLstmCrfNer : IExample
2424
public bool Run()
2525
{
2626
PrepareData();
27-
return true;
27+
return false;
2828
}
2929

3030
public void PrepareData()

0 commit comments

Comments
 (0)