@@ -21,23 +21,26 @@ public static Datasets read_data_sets(string train_dir,
2121 TF_DataType dtype = TF_DataType . TF_FLOAT ,
2222 bool reshape = true ,
2323 int validation_size = 5000 ,
24+ int test_size = 5000 ,
2425 string source_url = DEFAULT_SOURCE_URL )
2526 {
27+ var train_size = validation_size * 2 ;
28+
2629 Web . Download ( source_url + TRAIN_IMAGES , train_dir , TRAIN_IMAGES ) ;
2730 Compress . ExtractGZip ( Path . Join ( train_dir , TRAIN_IMAGES ) , train_dir ) ;
28- var train_images = extract_images ( Path . Join ( train_dir , TRAIN_IMAGES . Split ( '.' ) [ 0 ] ) ) ;
31+ var train_images = extract_images ( Path . Join ( train_dir , TRAIN_IMAGES . Split ( '.' ) [ 0 ] ) , limit : train_size ) ;
2932
3033 Web . Download ( source_url + TRAIN_LABELS , train_dir , TRAIN_LABELS ) ;
3134 Compress . ExtractGZip ( Path . Join ( train_dir , TRAIN_LABELS ) , train_dir ) ;
32- var train_labels = extract_labels ( Path . Join ( train_dir , TRAIN_LABELS . Split ( '.' ) [ 0 ] ) , one_hot : one_hot ) ;
35+ var train_labels = extract_labels ( Path . Join ( train_dir , TRAIN_LABELS . Split ( '.' ) [ 0 ] ) , one_hot : one_hot , limit : train_size ) ;
3336
3437 Web . Download ( source_url + TEST_IMAGES , train_dir , TEST_IMAGES ) ;
3538 Compress . ExtractGZip ( Path . Join ( train_dir , TEST_IMAGES ) , train_dir ) ;
36- var test_images = extract_images ( Path . Join ( train_dir , TEST_IMAGES . Split ( '.' ) [ 0 ] ) ) ;
39+ var test_images = extract_images ( Path . Join ( train_dir , TEST_IMAGES . Split ( '.' ) [ 0 ] ) , limit : test_size ) ;
3740
3841 Web . Download ( source_url + TEST_LABELS , train_dir , TEST_LABELS ) ;
3942 Compress . ExtractGZip ( Path . Join ( train_dir , TEST_LABELS ) , train_dir ) ;
40- var test_labels = extract_labels ( Path . Join ( train_dir , TEST_LABELS . Split ( '.' ) [ 0 ] ) , one_hot : one_hot ) ;
43+ var test_labels = extract_labels ( Path . Join ( train_dir , TEST_LABELS . Split ( '.' ) [ 0 ] ) , one_hot : one_hot , limit : test_size ) ;
4144
4245 int end = train_images . shape [ 0 ] ;
4346 var validation_images = train_images [ np . arange ( validation_size ) ] ;
@@ -52,14 +55,15 @@ public static Datasets read_data_sets(string train_dir,
5255 return new Datasets ( train , validation , test ) ;
5356 }
5457
55- public static NDArray extract_images ( string file )
58+ public static NDArray extract_images ( string file , int ? limit = null )
5659 {
5760 using ( var bytestream = new FileStream ( file , FileMode . Open ) )
5861 {
5962 var magic = _read32 ( bytestream ) ;
6063 if ( magic != 2051 )
6164 throw new ValueError ( $ "Invalid magic number { magic } in MNIST image file: { file } ") ;
62- var num_images = _read32 ( bytestream ) ;
65+ var num_images = _read32 ( bytestream ) ;
66+ num_images = limit == null ? num_images : Math . Min ( num_images , ( uint ) limit ) ;
6367 var rows = _read32 ( bytestream ) ;
6468 var cols = _read32 ( bytestream ) ;
6569 var buf = new byte [ rows * cols * num_images ] ;
@@ -70,14 +74,15 @@ public static NDArray extract_images(string file)
7074 }
7175 }
7276
73- public static NDArray extract_labels ( string file , bool one_hot = false , int num_classes = 10 )
77+ public static NDArray extract_labels ( string file , bool one_hot = false , int num_classes = 10 , int ? limit = null )
7478 {
7579 using ( var bytestream = new FileStream ( file , FileMode . Open ) )
7680 {
7781 var magic = _read32 ( bytestream ) ;
7882 if ( magic != 2049 )
7983 throw new ValueError ( $ "Invalid magic number { magic } in MNIST label file: { file } ") ;
8084 var num_items = _read32 ( bytestream ) ;
85+ num_items = limit == null ? num_items : Math . Min ( num_items , ( uint ) limit ) ;
8186 var buf = new byte [ num_items ] ;
8287 bytestream . Read ( buf , 0 , buf . Length ) ;
8388 var labels = np . frombuffer ( buf , np . uint8 ) ;
0 commit comments