Skip to content

Commit 808b95a

Browse files
committed
create image process examples' folder.
1 parent 0b1a8d3 commit 808b95a

15 files changed

Lines changed: 368 additions & 55 deletions

File tree

data/lstm_crf_ner.zip

769 Bytes
Binary file not shown.

graph/lstm_crf_ner.meta

1010 KB
Binary file not shown.
Lines changed: 73 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,92 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.IO;
34
using System.Text;
45

56
namespace Tensorflow.Estimator
67
{
78
public class HyperParams
89
{
9-
public string data_dir { get; set; }
10-
public string result_dir { get; set; }
11-
public string model_dir { get; set; }
12-
public string eval_dir { get; set; }
10+
/// <summary>
11+
/// root dir
12+
/// </summary>
13+
public string data_root_dir { get; set; }
14+
15+
/// <summary>
16+
/// results dir
17+
/// </summary>
18+
public string result_dir { get; set; } = "results";
19+
20+
/// <summary>
21+
/// model dir
22+
/// </summary>
23+
public string model_dir { get; set; } = "model";
24+
25+
public string eval_dir { get; set; } = "eval";
26+
27+
public string test_dir { get; set; } = "test";
1328

1429
public int dim { get; set; } = 300;
1530
public float dropout { get; set; } = 0.5f;
1631
public int num_oov_buckets { get; set; } = 1;
1732
public int epochs { get; set; } = 25;
33+
public int epoch_no_imprv { get; set; } = 3;
1834
public int batch_size { get; set; } = 20;
1935
public int buffer { get; set; } = 15000;
2036
public int lstm_size { get; set; } = 100;
37+
public string lr_method { get; set; } = "adam";
38+
public float lr { get; set; } = 0.001f;
39+
public float lr_decay { get; set; } = 0.9f;
40+
41+
/// <summary>
42+
/// lstm on chars
43+
/// </summary>
44+
public int hidden_size_char { get; set; } = 100;
45+
46+
/// <summary>
47+
/// lstm on word embeddings
48+
/// </summary>
49+
public int hidden_size_lstm { get; set; } = 300;
50+
51+
/// <summary>
52+
/// is clipping
53+
/// </summary>
54+
public bool clip { get; set; } = false;
55+
56+
public string filepath_dev { get; set; }
57+
public string filepath_test { get; set; }
58+
public string filepath_train { get; set; }
59+
60+
public string filepath_words { get; set; }
61+
public string filepath_chars { get; set; }
62+
public string filepath_tags { get; set; }
63+
public string filepath_glove { get; set; }
64+
65+
public HyperParams(string dataDir)
66+
{
67+
data_root_dir = dataDir;
68+
69+
if (string.IsNullOrEmpty(data_root_dir))
70+
throw new ValueError("Please specifiy the root data directory");
71+
72+
if (!Directory.Exists(data_root_dir))
73+
Directory.CreateDirectory(data_root_dir);
74+
75+
result_dir = Path.Combine(data_root_dir, result_dir);
76+
if (!Directory.Exists(result_dir))
77+
Directory.CreateDirectory(result_dir);
78+
79+
model_dir = Path.Combine(result_dir, model_dir);
80+
if (!Directory.Exists(model_dir))
81+
Directory.CreateDirectory(model_dir);
82+
83+
test_dir = Path.Combine(result_dir, test_dir);
84+
if (!Directory.Exists(test_dir))
85+
Directory.CreateDirectory(test_dir);
2186

22-
public string words { get; set; }
23-
public string chars { get; set; }
24-
public string tags { get; set; }
25-
public string glove { get; set; }
87+
eval_dir = Path.Combine(result_dir, eval_dir);
88+
if (!Directory.Exists(eval_dir))
89+
Directory.CreateDirectory(eval_dir);
90+
}
2691
}
2792
}

src/TensorFlowNET.Core/Framework/meta_graph.py.cs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,18 @@ public static (Dictionary<string, RefVariable>, ITensorOrOperation[]) import_sco
101101
switch (col.Key)
102102
{
103103
case "cond_context":
104-
var proto = CondContextDef.Parser.ParseFrom(value);
105-
var condContext = new CondContext().from_proto(proto, import_scope);
106-
graph.add_to_collection(col.Key, condContext);
104+
{
105+
var proto = CondContextDef.Parser.ParseFrom(value);
106+
var condContext = new CondContext().from_proto(proto, import_scope);
107+
graph.add_to_collection(col.Key, condContext);
108+
}
109+
break;
110+
case "while_context":
111+
{
112+
var proto = WhileContextDef.Parser.ParseFrom(value);
113+
var whileContext = new WhileContext().from_proto(proto, import_scope);
114+
graph.add_to_collection(col.Key, whileContext);
115+
}
107116
break;
108117
default:
109118
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ protected ControlFlowContext from_control_flow_context_def(ControlFlowContextDef
198198
{
199199
case CtxtOneofCase.CondCtxt:
200200
return new CondContext().from_proto(context_def.CondCtxt, import_scope: import_scope);
201+
case CtxtOneofCase.WhileCtxt:
202+
return new WhileContext().from_proto(context_def.WhileCtxt, import_scope: import_scope);
201203
}
202204

203205
throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}");

src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,70 @@
22
using System.Collections.Generic;
33
using System.Text;
44
using Tensorflow.Operations.ControlFlows;
5+
using static Tensorflow.Python;
56

67
namespace Tensorflow.Operations
78
{
9+
/// <summary>
10+
/// Creates a `WhileContext`.
11+
/// </summary>
812
public class WhileContext : ControlFlowContext
913
{
10-
private bool _back_prop=true;
14+
bool _back_prop=true;
15+
GradLoopState _grad_state =null;
16+
Tensor _maximum_iterations;
17+
int _parallel_iterations;
18+
bool _swap_memory;
19+
Tensor _pivot_for_pred;
20+
Tensor _pivot_for_body;
21+
Tensor[] _loop_exits;
22+
Tensor[] _loop_enters;
1123

12-
private GradLoopState _grad_state =null;
24+
public WhileContext(int parallel_iterations = 10,
25+
bool back_prop = true,
26+
bool swap_memory = false,
27+
string name = "while_context",
28+
GradLoopState grad_state = null,
29+
WhileContextDef context_def = null,
30+
string import_scope = null)
31+
{
32+
if (context_def != null)
33+
{
34+
_init_from_proto(context_def, import_scope: import_scope);
35+
}
36+
else
37+
{
38+
39+
}
40+
41+
_grad_state = grad_state;
42+
}
43+
44+
private void _init_from_proto(WhileContextDef context_def, string import_scope = null)
45+
{
46+
var g = ops.get_default_graph();
47+
_name = ops.prepend_name_scope(context_def.ContextName, import_scope);
48+
if (!string.IsNullOrEmpty(context_def.MaximumIterationsName))
49+
_maximum_iterations = g.as_graph_element(ops.prepend_name_scope(context_def.MaximumIterationsName, import_scope)) as Tensor;
50+
_parallel_iterations = context_def.ParallelIterations;
51+
_back_prop = context_def.BackProp;
52+
_swap_memory = context_def.SwapMemory;
53+
_pivot_for_pred = g.as_graph_element(ops.prepend_name_scope(context_def.PivotForPredName, import_scope)) as Tensor;
54+
// We use this node to control constants created by the body lambda.
55+
_pivot_for_body = g.as_graph_element(ops.prepend_name_scope(context_def.PivotForBodyName, import_scope)) as Tensor;
56+
// The boolean tensor for loop termination condition.
57+
_pivot = g.as_graph_element(ops.prepend_name_scope(context_def.PivotName, import_scope)) as Tensor;
58+
// The list of exit tensors for loop variables.
59+
_loop_exits = new Tensor[context_def.LoopExitNames.Count];
60+
foreach (var (i, exit_name) in enumerate(context_def.LoopExitNames))
61+
_loop_exits[i] = g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) as Tensor;
62+
// The list of enter tensors for loop variables.
63+
_loop_enters = new Tensor[context_def.LoopEnterNames.Count];
64+
foreach (var (i, enter_name) in enumerate(context_def.LoopEnterNames))
65+
_loop_enters[i] = g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) as Tensor;
66+
67+
__init__(values_def: context_def.ValuesDef, import_scope: import_scope);
68+
}
1369

1470
public override WhileContext GetWhileContext()
1571
{
@@ -21,9 +77,15 @@ public override WhileContext GetWhileContext()
2177

2278
public override bool back_prop => _back_prop;
2379

24-
public static WhileContext from_proto(object proto)
80+
public WhileContext from_proto(WhileContextDef proto, string import_scope)
2581
{
26-
throw new NotImplementedException();
82+
var ret = new WhileContext(context_def: proto, import_scope: import_scope);
83+
84+
ret.Enter();
85+
foreach (var nested_def in proto.NestedContexts)
86+
from_control_flow_context_def(nested_def, import_scope: import_scope);
87+
ret.Exit();
88+
return ret;
2789
}
2890

2991
public object to_proto()

src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ public virtual SaverDef _build_internal(RefVariable[] names_to_saveables,
120120
case List<CondContext> values:
121121
foreach (var element in values) ;
122122
break;
123+
case List<WhileContext> values:
124+
foreach (var element in values) ;
125+
break;
123126
default:
124127
throw new NotImplementedException("_build_internal.check_collection_list");
125128
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using TensorFlowNET.Examples.Utility;
5+
6+
namespace TensorFlowNET.Examples.ImageProcess
7+
{
8+
/// <summary>
9+
/// This example removes the background from an input image.
10+
///
11+
/// https://github.com/susheelsk/image-background-removal
12+
/// </summary>
13+
public class ImageBackgroundRemoval : IExample
14+
{
15+
public int Priority => 15;
16+
17+
public bool Enabled { get; set; } = true;
18+
public bool ImportGraph { get; set; } = true;
19+
20+
public string Name => "Image Background Removal";
21+
22+
string modelDir = "deeplabv3";
23+
24+
public bool Run()
25+
{
26+
return false;
27+
}
28+
29+
public void PrepareData()
30+
{
31+
// get model file
32+
string url = "http://download.tensorflow.org/models/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz";
33+
Web.Download(url, modelDir, "deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz");
34+
}
35+
}
36+
}

test/TensorFlowNET.Examples/ImageRecognitionInception.cs renamed to test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs

File renamed without changes.

test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs renamed to test/TensorFlowNET.Examples/ImageProcess/InceptionArchGoogLeNet.cs

File renamed without changes.

0 commit comments

Comments
 (0)