@@ -19,6 +19,7 @@ limitations under the License.
1919using System . IO ;
2020using System . Linq ;
2121using static Tensorflow . SaverDef . Types ;
22+ using static Tensorflow . Binding ;
2223
2324namespace Tensorflow
2425{
@@ -144,5 +145,54 @@ private static string _prefix_to_checkpoint_path(string prefix, CheckpointFormat
144145 return prefix + ".index" ;
145146 return prefix ;
146147 }
148+
149+ /// <summary>
150+ /// Finds the filename of latest saved checkpoint file.
151+ /// </summary>
152+ /// <param name="checkpoint_dir"></param>
153+ /// <param name="latest_filename"></param>
154+ /// <returns></returns>
155+ public static string latest_checkpoint ( string checkpoint_dir , string latest_filename = null )
156+ {
157+ // Pick the latest checkpoint based on checkpoint state.
158+ var ckpt = get_checkpoint_state ( checkpoint_dir , latest_filename ) ;
159+ if ( ckpt != null && ! string . IsNullOrEmpty ( ckpt . ModelCheckpointPath ) )
160+ {
161+ // Look for either a V2 path or a V1 path, with priority for V2.
162+ var v2_path = _prefix_to_checkpoint_path ( ckpt . ModelCheckpointPath , CheckpointFormatVersion . V2 ) ;
163+ var v1_path = _prefix_to_checkpoint_path ( ckpt . ModelCheckpointPath , CheckpointFormatVersion . V1 ) ;
164+ if ( File . Exists ( v2_path ) || File . Exists ( v1_path ) )
165+ return ckpt . ModelCheckpointPath ;
166+ else
167+ throw new ValueError ( $ "Couldn't match files for checkpoint { ckpt . ModelCheckpointPath } ") ;
168+ }
169+ return null ;
170+ }
171+
172+ public static CheckpointState get_checkpoint_state ( string checkpoint_dir , string latest_filename = null )
173+ {
174+ var coord_checkpoint_filename = _GetCheckpointFilename ( checkpoint_dir , latest_filename ) ;
175+ if ( File . Exists ( coord_checkpoint_filename ) )
176+ {
177+ var file_content = File . ReadAllBytes ( coord_checkpoint_filename ) ;
178+ var ckpt = CheckpointState . Parser . ParseFrom ( file_content ) ;
179+ if ( string . IsNullOrEmpty ( ckpt . ModelCheckpointPath ) )
180+ throw new ValueError ( $ "Invalid checkpoint state loaded from { checkpoint_dir } ") ;
181+ // For relative model_checkpoint_path and all_model_checkpoint_paths,
182+ // prepend checkpoint_dir.
183+ if ( ! Path . IsPathRooted ( ckpt . ModelCheckpointPath ) )
184+ ckpt . ModelCheckpointPath = Path . Combine ( checkpoint_dir , ckpt . ModelCheckpointPath ) ;
185+ foreach ( var i in range ( len ( ckpt . AllModelCheckpointPaths ) ) )
186+ {
187+ var p = ckpt . AllModelCheckpointPaths [ i ] ;
188+ if ( ! Path . IsPathRooted ( p ) )
189+ ckpt . AllModelCheckpointPaths [ i ] = Path . Combine ( checkpoint_dir , p ) ;
190+ }
191+
192+ return ckpt ;
193+ }
194+
195+ return null ;
196+ }
147197 }
148198}
0 commit comments