@@ -211,9 +211,11 @@ public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_pref
211211
212212 string restore_device = string . IsNullOrEmpty ( options . experimental_io_device ) ? "cpu:0" : options . experimental_io_device ! ;
213213
214- // tf python has code `with ops.device(restore_device):` here.
215- tf . device ( restore_device ) ; // may be risky.
216- var restored_tensors = gen_ops . restore_v2 ( file_prefix , tensor_names . ToArray ( ) , slice_specs . ToArray ( ) , tensor_dtypes . ToArray ( ) ) ;
214+ Tensor [ ] restored_tensors = null ;
215+ tf_with ( ops . device ( restore_device ) , _ =>
216+ {
217+ restored_tensors = gen_ops . restore_v2 ( file_prefix , tensor_names . ToArray ( ) , slice_specs . ToArray ( ) , tensor_dtypes . ToArray ( ) ) ;
218+ } ) ;
217219
218220 Dictionary < string , IDictionary < string , Tensor > > restored_tensor_dict = new ( ) ;
219221 int idx = 0 ;
@@ -338,11 +340,14 @@ public Operation save(Tensor file_prefix, CheckpointOptions? options= null)
338340 options = new CheckpointOptions ( ) ;
339341 }
340342
341- tf . device ( "CPU" ) ; // may be risky.
342- var sharded_suffix = array_ops . where ( gen_ops . regex_full_match ( file_prefix , tf . constant ( @"^s3://.*" ) ) ,
343+ Tensor tmp_checkpoint_prefix = null ;
344+ tf_with ( ops . device ( "CPU" ) , _ =>
345+ {
346+ var sharded_suffix = array_ops . where ( gen_ops . regex_full_match ( file_prefix , tf . constant ( @"^s3://.*" ) ) ,
343347 constant_op . constant ( ".part" ) , constant_op . constant ( "_temp/part" ) ) ;
344- var tmp_checkpoint_prefix = gen_ops . string_join ( new Tensor [ ] { file_prefix , sharded_suffix } ) ;
345- IDictionary < string , Tensor > registered_paths = _registered_savers . Keys . ToDictionary ( x => x , x => registered_saver_filename ( file_prefix , x ) ) ;
348+ tmp_checkpoint_prefix = gen_ops . string_join ( new Tensor [ ] { file_prefix , sharded_suffix } ) ;
349+ IDictionary < string , Tensor > registered_paths = _registered_savers . Keys . ToDictionary ( x => x , x => registered_saver_filename ( file_prefix , x ) ) ;
350+ } ) ;
346351
347352 Operation save_fn ( )
348353 {
@@ -364,16 +369,24 @@ Operation save_fn()
364369 var saver = pair . Value ;
365370 last_device = device ;
366371 // skip the extra process of device name because of lack of API.
367- tf . device ( device ) ;
368- var shard_prefix = sharded_filename ( tmp_checkpoint_prefix , shard , num_shards_tensor ) ;
372+ Tensor shard_prefix = null ;
373+ tf_with ( ops . device ( device ) , _ =>
374+ {
375+ shard_prefix = sharded_filename ( tmp_checkpoint_prefix , shard , num_shards_tensor ) ;
376+ } ) ;
369377 saved_prefixes . Add ( shard_prefix ) ;
370- sharded_saves . Add ( saver . save ( shard_prefix , options ) ) ;
378+ tf_with ( ops . device ( device ) , _ =>
379+ {
380+ sharded_saves . Add ( saver . save ( shard_prefix , options ) ) ;
381+ } ) ;
371382 }
372383 using ( var controller = ops . control_dependencies ( sharded_saves . ToArray ( ) ) )
373384 {
374385 string merge_device = string . IsNullOrEmpty ( options . experimental_io_device ) ? last_device : options . experimental_io_device ;
375- tf . device ( merge_device ) ;
376- return gen_ops . merge_v2_checkpoints ( saved_prefixes . ToArray ( ) , tf . constant ( file_prefix ) , delete_old_dirs : true ) ;
386+ return tf_with ( ops . device ( merge_device ) , _ =>
387+ {
388+ return gen_ops . merge_v2_checkpoints ( saved_prefixes . ToArray ( ) , tf . constant ( file_prefix ) , delete_old_dirs : true ) ;
389+ } ) ;
377390 }
378391 }
379392
@@ -407,54 +420,56 @@ IDictionary<string, Operation> restore_func()
407420 {
408421 var device = single_saver . Key ;
409422 var saver = single_saver . Value ;
410- tf . device ( device ) ;
411- var restored_tensor_dict = saver . restore ( file_prefix , options ) ;
412-
413- foreach ( var pair in restored_tensor_dict )
423+ tf_with ( ops . device ( device ) , _ =>
414424 {
415- var checkpoint_key = pair . Key ;
416- var slice_and_tensor = pair . Value ;
417- foreach ( var item in slice_and_tensor )
425+ var restored_tensor_dict = saver . restore ( file_prefix , options ) ;
426+
427+ foreach ( var pair in restored_tensor_dict )
418428 {
419- var slice_spec = item . Key ;
420- var tensor = item . Value ;
421- var restore_fn = _keys_to_restore_fn [ ( checkpoint_key , slice_spec ) ] ;
422- var internal_dict = restore_fn_inputs . SetDefault ( restore_fn , new Dictionary < string , Maybe < Tensor , IDictionary < string , Tensor > > > ( ) ) ;
423- if ( ! string . IsNullOrEmpty ( slice_spec ) )
429+ var checkpoint_key = pair . Key ;
430+ var slice_and_tensor = pair . Value ;
431+ foreach ( var item in slice_and_tensor )
424432 {
425- if ( ! internal_dict . ContainsKey ( checkpoint_key ) )
433+ var slice_spec = item . Key ;
434+ var tensor = item . Value ;
435+ var restore_fn = _keys_to_restore_fn [ ( checkpoint_key , slice_spec ) ] ;
436+ var internal_dict = restore_fn_inputs . SetDefault ( restore_fn , new Dictionary < string , Maybe < Tensor , IDictionary < string , Tensor > > > ( ) ) ;
437+ if ( ! string . IsNullOrEmpty ( slice_spec ) )
426438 {
427- Dictionary < string , Tensor > dict = new ( ) ;
428- dict [ slice_spec ] = tensor ;
429- internal_dict [ checkpoint_key ] = new Maybe < Tensor , IDictionary < string , Tensor > > ( dict ) ;
439+ if ( ! internal_dict . ContainsKey ( checkpoint_key ) )
440+ {
441+ Dictionary < string , Tensor > dict = new ( ) ;
442+ dict [ slice_spec ] = tensor ;
443+ internal_dict [ checkpoint_key ] = new Maybe < Tensor , IDictionary < string , Tensor > > ( dict ) ;
444+ }
445+ else
446+ {
447+ internal_dict [ checkpoint_key ] . GetValue < IDictionary < string , Tensor > > ( ) [ slice_spec ] = tensor ;
448+ }
430449 }
431450 else
432451 {
433- internal_dict [ checkpoint_key ] . GetValue < IDictionary < string , Tensor > > ( ) [ slice_spec ] = tensor ;
452+ internal_dict [ checkpoint_key ] = new Maybe < Tensor , IDictionary < string , Tensor > > ( tensor ) ;
434453 }
435- }
436- else
437- {
438- internal_dict [ checkpoint_key ] = new Maybe < Tensor , IDictionary < string , Tensor > > ( tensor ) ;
439- }
440- restore_fn_input_count [ restore_fn ] -- ;
454+ restore_fn_input_count [ restore_fn ] -- ;
441455
442- if ( restore_fn_input_count [ restore_fn ] == 0 )
443- {
444- Dictionary < string , Maybe < Tensor , IDictionary < string , Tensor > > > restored_tensors = new ( ) ;
445- foreach ( var input in restore_fn_inputs [ restore_fn ] )
456+ if ( restore_fn_input_count [ restore_fn ] == 0 )
446457 {
447- restored_tensors [ TrackableUtils . extract_local_name ( input . Key ) ] = input . Value ;
448- }
449- var ret = restore_fn . DynamicInvoke ( restored_tensors ) ;
450- if ( ret is IDictionary < string , Operation > )
451- {
452- var dict = ( IDictionary < string , Operation > ) ret ;
453- restore_ops = restore_ops . Concat ( dict ) . ToDictionary ( x => x . Key , x => x . Value ) ;
458+ Dictionary < string , Maybe < Tensor , IDictionary < string , Tensor > > > restored_tensors = new ( ) ;
459+ foreach ( var input in restore_fn_inputs [ restore_fn ] )
460+ {
461+ restored_tensors [ TrackableUtils . extract_local_name ( input . Key ) ] = input . Value ;
462+ }
463+ var ret = restore_fn . DynamicInvoke ( restored_tensors ) ;
464+ if ( ret is IDictionary < string , Operation > )
465+ {
466+ var dict = ( IDictionary < string , Operation > ) ret ;
467+ restore_ops = restore_ops . Concat ( dict ) . ToDictionary ( x => x . Key , x => x . Value ) ;
468+ }
454469 }
455470 }
456471 }
457- }
472+ } ) ;
458473 }
459474
460475 foreach ( var item in _registered_savers )
@@ -500,21 +515,25 @@ public SaverDef to_proto()
500515 private Tensor _traced_save ( Tensor file_prefix )
501516 {
502517 var save_op = save ( file_prefix ) ;
503- tf . device ( "cpu:0" ) ;
504- using ( ops . control_dependencies ( new object [ ] { save_op } ) )
518+ return tf_with ( ops . device ( "cpu:0" ) , _ =>
505519 {
506- return array_ops . identity ( file_prefix ) ;
507- }
520+ return tf_with ( ops . control_dependencies ( new object [ ] { save_op } ) , __ =>
521+ {
522+ return array_ops . identity ( file_prefix ) ;
523+ } ) ;
524+ } ) ;
508525 }
509526
510527 private Tensor _traced_restore ( Tensor file_prefix )
511528 {
512529 var restore_op = restore ( file_prefix ) ;
513- tf . device ( "cpu:0" ) ;
514- using ( ops . control_dependencies ( restore_op . Values . ToArray ( ) ) )
530+ return tf_with ( ops . device ( "cpu:0" ) , _ =>
515531 {
516- return array_ops . identity ( file_prefix ) ;
517- }
532+ return tf_with ( ops . control_dependencies ( restore_op . Values . ToArray ( ) ) , __ =>
533+ {
534+ return array_ops . identity ( file_prefix ) ;
535+ } ) ;
536+ } ) ;
518537 }
519538
520539 public static MultiDeviceSaver from_saveables ( IEnumerable < MySaveableObject > saveables , IDictionary < string , IDictionary < string , Trackable > > ? registered_savers = null , bool call_with_mapped_captures = false )
0 commit comments