@@ -123,7 +123,7 @@ public static (Dictionary<string, RefVariable>, ITensorOrOperation[]) import_sco
123123 /// <param name="strip_default_attrs"></param>
124124 /// <param name="meta_info_def"></param>
125125 /// <returns></returns>
126- public static MetaGraphDef export_scoped_meta_graph ( string filename = "" ,
126+ public static ( MetaGraphDef , Dictionary < string , RefVariable > ) export_scoped_meta_graph ( string filename = "" ,
127127 GraphDef graph_def = null ,
128128 bool as_text = false ,
129129 string unbound_inputs_col_name = "unbound_inputs" ,
@@ -138,7 +138,7 @@ public static MetaGraphDef export_scoped_meta_graph(string filename = "",
138138 var var_list = new Dictionary < string , RefVariable > ( ) ;
139139 var variables = graph . get_collection ( ops . GraphKeys . GLOBAL_VARIABLES ) ;
140140
141- foreach ( var v in variables as RefVariable [ ] )
141+ foreach ( var v in variables as List < RefVariable > )
142142 {
143143 var_list [ v . name ] = v ;
144144 }
@@ -151,15 +151,18 @@ public static MetaGraphDef export_scoped_meta_graph(string filename = "",
151151 saver_def : saver_def ,
152152 strip_default_attrs : strip_default_attrs ) ;
153153
154- throw new NotImplementedException ( "meta_graph.export_scoped_meta_graph" ) ;
154+ if ( ! string . IsNullOrEmpty ( filename ) )
155+ graph_io . write_graph ( scoped_meta_graph_def , "" , filename , as_text : as_text ) ;
156+
157+ return ( scoped_meta_graph_def , var_list ) ;
155158 }
156159
157160 private static bool _should_include_node ( )
158161 {
159162 return true ;
160163 }
161164
162- private static byte [ ] create_meta_graph_def ( MetaInfoDef meta_info_def = null ,
165+ private static MetaGraphDef create_meta_graph_def ( MetaInfoDef meta_info_def = null ,
163166 GraphDef graph_def = null ,
164167 string export_scope = "" ,
165168 string exclude_nodes = "" ,
@@ -168,7 +171,7 @@ private static byte[] create_meta_graph_def(MetaInfoDef meta_info_def = null,
168171 bool strip_default_attrs = false )
169172 {
170173 // Sets graph to default graph if it's not passed in.
171- var graph = ops . get_default_graph ( ) ;
174+ var graph = ops . get_default_graph ( ) . as_default ( ) ;
172175 // Creates a MetaGraphDef proto.
173176 var meta_graph_def = new MetaGraphDef ( ) ;
174177 if ( meta_info_def == null )
@@ -186,10 +189,55 @@ private static byte[] create_meta_graph_def(MetaInfoDef meta_info_def = null,
186189 meta_graph_def . GraphDef = graph_def ;
187190
188191 // Fills in meta_info_def.stripped_op_list using the ops from graph_def.
189- if ( meta_graph_def . MetaInfoDef . StrippedOpList . Op . Count == 0 )
192+ if ( meta_graph_def . MetaInfoDef . StrippedOpList == null ||
193+ meta_graph_def . MetaInfoDef . StrippedOpList . Op . Count == 0 )
190194 meta_graph_def . MetaInfoDef . StrippedOpList = stripped_op_list_for_graph ( meta_graph_def . GraphDef ) ;
191195
192- throw new NotImplementedException ( "create_meta_graph_def" ) ;
196+ var clist = graph . get_all_collection_keys ( ) ;
197+ foreach ( var ctype in clist )
198+ {
199+ if ( clear_extraneous_savers )
200+ {
201+ throw new NotImplementedException ( "create_meta_graph_def clear_extraneous_savers" ) ;
202+ }
203+ else
204+ {
205+ add_collection_def ( meta_graph_def , ctype , graph ) ;
206+ }
207+ }
208+
209+ return meta_graph_def ;
210+ }
211+
212+ private static void add_collection_def ( MetaGraphDef meta_graph_def ,
213+ string key ,
214+ Graph graph = null ,
215+ string export_scope = "" )
216+ {
217+ if ( ! meta_graph_def . CollectionDef . ContainsKey ( key ) )
218+ meta_graph_def . CollectionDef [ key ] = new CollectionDef ( ) ;
219+ var col_def = meta_graph_def . CollectionDef [ key ] ;
220+
221+ switch ( graph . get_collection ( key ) )
222+ {
223+ case List < RefVariable > collection_list :
224+ col_def . BytesList = new Types . BytesList ( ) ;
225+ foreach ( var x in collection_list )
226+ {
227+ var proto = x . to_proto ( export_scope ) ;
228+ col_def . BytesList . Value . Add ( proto . ToByteString ( ) ) ;
229+ }
230+
231+ break ;
232+ case List < object > collection_list :
233+ col_def . NodeList = new Types . NodeList ( ) ;
234+ foreach ( var x in collection_list )
235+ if ( x is ITensorOrOperation x2 )
236+ col_def . NodeList . Value . Add ( ops . strip_name_scope ( x2 . name , export_scope ) ) ;
237+ break ;
238+ case List < Operation > collection_list :
239+ break ;
240+ }
193241 }
194242
195243 private static OpList stripped_op_list_for_graph ( GraphDef graph_def )
0 commit comments