1- using System ;
1+ using NumSharp ;
2+ using System ;
23using System . Collections . Generic ;
4+ using System . Linq ;
35using System . Text ;
6+ using static Tensorflow . Python ;
47
58namespace Tensorflow
69{
@@ -23,7 +26,181 @@ public GraphDef convert_variables_to_constants(Session sess,
2326 {
2427 // This graph only includes the nodes needed to evaluate the output nodes, and
2528 // removes unneeded nodes like those involved in saving and assignment.
26- throw new NotImplementedException ( "" ) ;
29+ var inference_graph = extract_sub_graph ( input_graph_def , output_node_names ) ;
30+
31+ // Identify the ops in the graph.
32+ var map_name_to_node = new Dictionary < string , NodeDef > ( ) ;
33+ inference_graph . Node . Select ( x => map_name_to_node [ x . Name ] = x ) . ToArray ( ) ;
34+
35+ // Get list of variables.
36+ var variable_names = new List < string > ( ) ;
37+ var variable_dict_names = new List < string > ( ) ;
38+
39+ foreach ( var node in inference_graph . Node )
40+ {
41+ if ( new string [ ] { "Variable" , "VariableV2" , "VarHandleOp" } . Contains ( node . Op ) )
42+ {
43+ var variable_name = node . Name ;
44+
45+ variable_dict_names . Add ( variable_name ) ;
46+ if ( node . Op == "VarHandleOp" )
47+ variable_names . Add ( variable_name + "/Read/ReadVariableOp:0" ) ;
48+ else
49+ variable_names . Add ( variable_name + ":0" ) ;
50+ }
51+ else if ( new string [ ] { "ReadVariableOp" , "ResourceGather" } . Contains ( node . Op ) )
52+ {
53+ // There can be one or more Identity ops in between the ReadVariableOp and
54+ // VarHandleOp. Store the Identity ops with the associated dtypes.
55+ var source_op_name = get_input_name ( node ) ;
56+ while ( map_name_to_node [ source_op_name ] . Op == "Identity" )
57+ {
58+ throw new NotImplementedException ( "map_name_to_node[source_op_name].Op" ) ;
59+ /*resource_identity_types[source_op_name] = node.attr["dtype"];
60+ source_op_name = get_input_name(map_name_to_node[source_op_name]);*/
61+ }
62+ }
63+ }
64+
65+ // Gets map of variables and the associated data.
66+ NDArray returned_variables = null ;
67+ if ( variable_names != null )
68+ returned_variables = sess . run ( variable_names ) ;
69+
70+ var variables_data_map = new Dictionary < string , NDArray > ( ) ;
71+ foreach ( var ( i , name ) in enumerate ( variable_dict_names ) )
72+ variables_data_map [ name ] = returned_variables [ i ] ;
73+ print ( $ "Froze { len ( returned_variables ) } variables.") ;
74+
75+ // Reconstruct the graph with constants in place of variables.
76+ var output_graph_def = new GraphDef ( ) ;
77+ int how_many_converted = 0 ;
78+ foreach ( var input_node in inference_graph . Node )
79+ {
80+ var output_node = new NodeDef ( ) ;
81+ if ( variables_data_map . ContainsKey ( input_node . Name ) )
82+ {
83+ var data = variables_data_map [ input_node . Name ] ;
84+ output_node = create_const_op ( input_node . Name , input_node . Attr [ "dtype" ] ,
85+ data , data . shape ) ;
86+ how_many_converted += 1 ;
87+ }
88+ // else if (resource_identity_types.ContainsKey(input_node.Name))
89+ else if ( input_node . Op == "ReadVariableOp" )
90+ {
91+ output_node . Op = "Identity" ;
92+ output_node . Name = input_node . Name ;
93+ output_node . Input . AddRange ( new [ ] { input_node . Input [ 0 ] } ) ;
94+ output_node . Attr [ "T" ] = input_node . Attr [ "dtype" ] ;
95+ }
96+ else if ( input_node . Op == "ResourceGather" )
97+ {
98+
99+ }
100+ else
101+ {
102+ output_node . MergeFrom ( input_node ) ;
103+ }
104+
105+ output_graph_def . Node . AddRange ( new [ ] { output_node } ) ;
106+ }
107+
108+ output_graph_def . Library = inference_graph . Library ;
109+ print ( $ "Converted { how_many_converted } variables to const ops.") ;
110+ return output_graph_def ;
111+ }
112+
113+ private NodeDef create_const_op ( string node_name , AttrValue dtype , NDArray data , int [ ] data_shape = null )
114+ {
115+ var output_node = new NodeDef
116+ {
117+ Op = "Const" ,
118+ Name = node_name
119+ } ;
120+ output_node . Attr [ "dtype" ] = dtype ;
121+ output_node . Attr [ "value" ] = new AttrValue ( )
122+ {
123+ Tensor = tensor_util . make_tensor_proto (
124+ data , dtype : dtype . Type . as_tf_dtype ( ) , shape : data_shape )
125+ } ;
126+
127+ return output_node ;
128+ }
129+
130+ /// <summary>
131+ /// Gets the name of the first input. Errors if suffix is not :0.
132+ /// </summary>
133+ /// <param name="node"></param>
134+ /// <returns></returns>
135+ private string get_input_name ( NodeDef node )
136+ {
137+ var details = node . Input [ 0 ] . Split ( ':' ) ;
138+ if ( details . Length == 1 || int . Parse ( details [ 1 ] ) == 0 )
139+ return details [ 0 ] ;
140+ // While it is valid for input tensors to have a suffix that is not :0, this
141+ // method is used to find the associated ops, not tensors, and therefore it
142+ // is not valid.
143+ throw new ValueError ( $ "Tensor name '{ node . Input [ 0 ] } ' is invalid.") ;
144+ }
145+
146+
147+ private GraphDef extract_sub_graph ( GraphDef graph_def , string [ ] dest_nodes )
148+ {
149+ var ( name_to_input_name , name_to_node , name_to_seq_num ) = _extract_graph_summary (
150+ graph_def ) ;
151+
152+ var nodes_to_keep = _bfs_for_reachable_nodes ( dest_nodes , name_to_input_name ) ;
153+ var nodes_to_keep_list = nodes_to_keep . OrderBy ( n => name_to_seq_num [ n ] ) . ToArray ( ) ;
154+ // Now construct the output GraphDef
155+ var output = new GraphDef ( ) ;
156+ foreach ( var n in nodes_to_keep_list )
157+ output . Node . Add ( name_to_node [ n ] ) ; // need deep clone?
158+ output . Library = graph_def . Library ;
159+ output . Versions = graph_def . Versions ;
160+
161+ return output ;
162+ }
163+
164+ private string [ ] _bfs_for_reachable_nodes ( string [ ] target_nodes , Dictionary < string , string [ ] > name_to_input_name )
165+ {
166+ var nodes_to_keep = new List < string > ( ) ;
167+ var next_to_visit = target_nodes . Select ( x => x ) . ToList ( ) ;
168+ while ( next_to_visit . Count > 0 )
169+ {
170+ var node = next_to_visit [ 0 ] ;
171+ next_to_visit . RemoveAt ( 0 ) ;
172+ if ( nodes_to_keep . Contains ( node ) )
173+ continue ;
174+ nodes_to_keep . Add ( node ) ;
175+ if ( name_to_input_name . Keys . Contains ( node ) )
176+ next_to_visit . AddRange ( name_to_input_name [ node ] ) ;
177+ }
178+
179+ return nodes_to_keep . ToArray ( ) ;
180+ }
181+
182+ private ( Dictionary < string , string [ ] > , Dictionary < string , NodeDef > , Dictionary < string , int > ) _extract_graph_summary ( GraphDef graph_def )
183+ {
184+ var name_to_input_name = new Dictionary < string , string [ ] > ( ) ;
185+ var name_to_node = new Dictionary < string , NodeDef > ( ) ;
186+ var name_to_seq_num = new Dictionary < string , int > ( ) ;
187+
188+ int seq = 0 ;
189+ foreach ( var node in graph_def . Node )
190+ {
191+ var n = _node_name ( node . Name ) ;
192+ name_to_node [ n ] = node ;
193+ name_to_input_name [ n ] = node . Input . Select ( x => _node_name ( x ) ) . ToArray ( ) ;
194+ name_to_seq_num [ n ] = seq ;
195+ seq ++ ;
196+ }
197+
198+ return ( name_to_input_name , name_to_node , name_to_seq_num ) ;
199+ }
200+
201+ private string _node_name ( string n )
202+ {
203+ return n . StartsWith ( "^" ) ? n . Substring ( 1 ) : n . Split ( ':' ) [ 0 ] ;
27204 }
28205
29206 private string get_input_name ( string node )
0 commit comments