Skip to content

Commit de3ecc8

Browse files
committed
create_estimator_and_inputs
1 parent 67eeab0 commit de3ecc8

6 files changed

Lines changed: 14 additions & 10 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.estimator.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ public partial class tensorflow
2626

2727
public class Estimator_Internal
2828
{
29-
public Estimator Estimator(RunConfig config)
30-
=> new Estimator(config: config);
29+
public Estimator Estimator(Action model_fn, RunConfig config)
30+
=> new Estimator(model_fn: model_fn, config: config);
3131

3232
public RunConfig RunConfig(string model_dir)
3333
=> new RunConfig(model_dir: model_dir);

src/TensorFlowNET.Core/Estimators/Estimator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public class Estimator : IObjectLife
1818

1919
string _model_dir;
2020

21-
public Estimator(RunConfig config)
21+
public Estimator(Action model_fn, RunConfig config)
2222
{
2323
_config = config;
2424
_model_dir = _config.model_dir;

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<TargetTensorFlow>1.14.0</TargetTensorFlow>
8-
<Version>0.11.3</Version>
8+
<Version>0.11.4</Version>
99
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
1010
<Company>SciSharp STACK</Company>
1111
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -17,7 +17,7 @@
1717
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
1818
<Description>Google's TensorFlow full binding in .NET Standard.
1919
Docs: https://tensorflownet.readthedocs.io</Description>
20-
<AssemblyVersion>0.11.3.0</AssemblyVersion>
20+
<AssemblyVersion>0.11.4.0</AssemblyVersion>
2121
<PackageReleaseNotes>Changes since v0.10.0:
2222
1. Upgrade NumSharp to v0.20.
2323
2. Add DisposableObject class to manage object lifetime.
@@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
3030
9. MultiThread is safe.
3131
10. Support n-dim indexing for tensor.</PackageReleaseNotes>
3232
<LangVersion>7.3</LangVersion>
33-
<FileVersion>0.11.3.0</FileVersion>
33+
<FileVersion>0.11.4.0</FileVersion>
3434
<PackageLicenseFile>LICENSE</PackageLicenseFile>
3535
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
3636
<SignAssembly>true</SignAssembly>

src/TensorFlowNET.Models/ObjectDetection/ModelLib.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,20 @@ public TrainAndEvalDict create_estimator_and_inputs(RunConfig run_config,
1818
int sample_1_of_n_eval_examples = 0,
1919
int sample_1_of_n_eval_on_train_examples = 1)
2020
{
21-
var estimator = tf.estimator.Estimator(config: run_config);
22-
2321
var config = ConfigUtil.get_configs_from_pipeline_file(pipeline_config_path);
2422
var eval_input_configs = config.EvalInputReader;
2523

2624
var eval_input_fns = new Action[eval_input_configs.Count];
25+
var eval_input_names = eval_input_configs.Select(eval_input_config => eval_input_config.Name).ToArray();
26+
Action model_fn = () => { };
27+
var estimator = tf.estimator.Estimator(model_fn: model_fn, config: run_config);
2728

2829
return new TrainAndEvalDict
2930
{
3031
estimator = estimator,
3132
train_steps = train_steps,
32-
eval_input_fns = eval_input_fns
33+
eval_input_fns = eval_input_fns,
34+
eval_input_names = eval_input_names
3335
};
3436
}
3537

test/TensorFlowNET.UnitTest/Estimators/RunConfigTest.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ public void test_default_property_values()
5353
Assert.IsNull(config.tf_random_seed);
5454
Assert.AreEqual(100, config.save_summary_steps);
5555
Assert.AreEqual(600, config.save_checkpoints_secs);
56-
Assert.IsNull(config.save_checkpoints_steps);
5756
Assert.AreEqual(5, config.keep_checkpoint_max);
5857
Assert.AreEqual(10000, config.keep_checkpoint_every_n_hours);
5958
Assert.IsNull(config.service);

test/TensorFlowNET.UnitTest/MultithreadingTests.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ void Core(int tid)
289289
[TestMethod]
290290
public void TF_GraphOperationByName_FromModel()
291291
{
292+
if (!Directory.Exists(modelPath))
293+
return;
294+
292295
MultiThreadedUnitTestExecuter.Run(8, Core);
293296

294297
//the core method

0 commit comments

Comments
 (0)