1- using System ;
2- using System . Collections . Generic ;
3- using System . IO ;
4- using System . Linq ;
5- using System . Runtime . InteropServices ;
6- using System . Text ;
7- using Tensorflow . Util ;
1+ using Tensorflow . Util ;
82
93namespace Tensorflow . Checkpoint
104{
11- public class CheckpointReader : SafeTensorflowHandle
5+ sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
126 {
7+ public SafeCheckpointReaderHandle ( ) : base ( )
8+ {
9+
10+ }
11+ public SafeCheckpointReaderHandle ( IntPtr handle ) : base ( handle )
12+ {
13+
14+ }
15+
16+ protected override bool ReleaseHandle ( )
17+ {
18+ //if (handle != IntPtr.Zero)
19+ //{
20+ // c_api.TF_DeleteCheckpointReader(this);
21+ //}
22+ return true ;
23+ }
24+ }
25+ public class CheckpointReader
26+ {
27+ private SafeCheckpointReaderHandle _handle ;
1328 public Dictionary < string , TF_DataType > VariableToDataTypeMap { get ; set ; }
1429 public Dictionary < string , Shape > VariableToShapeMap { get ; set ; }
1530
1631 public CheckpointReader ( string filename )
1732 {
1833 Status status = new Status ( ) ;
19- handle = c_api . TF_NewCheckpointReader ( filename , status . Handle ) ;
34+ _handle = c_api . TF_NewCheckpointReader ( filename , status . Handle ) ;
2035 status . Check ( true ) ;
2136 ReadAllShapeAndType ( ) ;
2237 }
2338
2439 public int HasTensor ( string name )
2540 {
26- return c_api . TF_CheckpointReaderHasTensor ( handle , name ) ;
41+ return c_api . TF_CheckpointReaderHasTensor ( _handle , name ) ;
2742 }
2843
2944 /// <summary>
@@ -33,45 +48,39 @@ public int HasTensor(string name)
3348 /// <returns></returns>
3449 public string GetVariable ( int index )
3550 {
36- return c_api . TF_CheckpointReaderGetVariable ( handle , index ) ;
51+ return c_api . StringPiece ( c_api . TF_CheckpointReaderGetVariable ( _handle , index ) ) ;
3752 }
3853
3954 public int Size ( )
4055 {
41- return c_api . TF_CheckpointReaderSize ( handle ) ;
56+ return c_api . TF_CheckpointReaderSize ( _handle ) ;
4257 }
4358
4459 public TF_DataType GetVariableDataType ( string name )
4560 {
46- return c_api . TF_CheckpointReaderGetVariableDataType ( handle , name ) ;
61+ return c_api . TF_CheckpointReaderGetVariableDataType ( _handle , name ) ;
4762 }
4863
4964 public Shape GetVariableShape ( string name )
5065 {
51- // TODO(Rinne): Change it to a constant.
5266 int num_dims = GetVariableNumDims ( name ) ;
5367 long [ ] dims = new long [ num_dims ] ;
5468 Status status = new Status ( ) ;
55- c_api . TF_CheckpointReaderGetVariableShape ( handle , name , dims , num_dims , status . Handle ) ;
69+ c_api . TF_CheckpointReaderGetVariableShape ( _handle , name , dims , num_dims , status . Handle ) ;
5670 status . Check ( true ) ;
5771 return new Shape ( dims ) ;
5872 }
5973
6074 public int GetVariableNumDims ( string name )
6175 {
62- return c_api . TF_CheckpointReaderGetVariableNumDims ( handle , name ) ;
76+ return c_api . TF_CheckpointReaderGetVariableNumDims ( _handle , name ) ;
6377 }
6478
6579 public unsafe Tensor GetTensor ( string name , TF_DataType dtype = TF_DataType . DtInvalid )
6680 {
6781 Status status = new Status ( ) ;
68- var tensor = c_api . TF_CheckpointReaderGetTensor ( handle , name , status . Handle ) ;
82+ var tensor = c_api . TF_CheckpointReaderGetTensor ( _handle , name , status . Handle ) ;
6983 status . Check ( true ) ;
70- var shape = GetVariableShape ( name ) ;
71- if ( dtype == TF_DataType . DtInvalid )
72- {
73- dtype = GetVariableDataType ( name ) ;
74- }
7584 return new Tensor ( tensor ) ;
7685 }
7786
@@ -89,16 +98,5 @@ private void ReadAllShapeAndType()
8998 VariableToShapeMap [ name ] = shape ;
9099 }
91100 }
92-
93- protected override bool ReleaseHandle ( )
94- {
95- c_api . TF_DeleteCheckpointReader ( handle ) ;
96- return true ;
97- }
98-
99- public void Dispose ( )
100- {
101- c_api . TF_DeleteCheckpointReader ( handle ) ;
102- }
103101 }
104102}
0 commit comments