1+ """Functions for downloading and reading MNIST data."""
2+ from __future__ import print_function
3+ import gzip
4+ import os
5+ import urllib
6+ import numpy
# Base URL to which an MNIST archive filename is appended when downloading.
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
    """Download `filename` into `work_directory` unless it is already there.

    Args:
        filename: Name of the archive; it is appended to SOURCE_URL to form
            the download URL and joined to `work_directory` to form the
            local path.
        work_directory: Directory that holds the downloaded files. It is
            created (including missing parent directories) when absent.

    Returns:
        The local path of the file.
    """
    # makedirs (rather than mkdir) so that a nested target directory works.
    if not os.path.exists(work_directory):
        os.makedirs(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        # urlretrieve lives in different modules on Python 2 and Python 3.
        try:
            from urllib.request import urlretrieve
        except ImportError:
            from urllib import urlretrieve
        filepath, _ = urlretrieve(SOURCE_URL + filename, filepath)
        statinfo = os.stat(filepath)
        print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath
def _read32(bytestream):
    """Read one big-endian unsigned 32-bit integer from `bytestream`.

    Args:
        bytestream: Binary file-like object; exactly 4 bytes are consumed.

    Returns:
        The decoded value as a plain Python int, so it can be used directly
        as a read size, a reshape dimension or a `%d` format argument.

    Raises:
        ValueError: If fewer than 4 bytes could be read.
    """
    raw = bytestream.read(4)
    if len(raw) != 4:
        raise ValueError(
            'Expected 4 bytes for a 32-bit field, got %d' % len(raw))
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    # Index element 0: a 1-element array is not accepted as a scalar size
    # or index by current NumPy / file APIs.
    return int(numpy.frombuffer(raw, dtype=dt)[0])
def extract_images(filename):
    """Load a gzipped MNIST image file as a uint8 array.

    The file starts with four 32-bit header fields (magic number, image
    count, rows, columns) followed by the raw pixel bytes.

    Returns:
        A numpy array reshaped to [index, y, x, depth] with depth 1.

    Raises:
        ValueError: If the magic number is not 2051.
    """
    print('Extracting', filename)
    with gzip.open(filename) as stream:
        header_magic = _read32(stream)
        if header_magic != 2051:
            message = ('Invalid magic number %d in MNIST image file: %s' %
                       (header_magic, filename))
            raise ValueError(message)
        # Remaining header fields, in file order.
        count, height, width = [_read32(stream) for _ in range(3)]
        raw_pixels = stream.read(height * width * count)
        pixels = numpy.frombuffer(raw_pixels, dtype=numpy.uint8)
        return pixels.reshape(count, height, width, 1)
def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors.

    Args:
        labels_dense: Array (or sequence) of integer class labels; its first
            dimension gives the number of labels.
        num_classes: Number of columns in the result.

    Returns:
        A `[num_labels, num_classes]` array of zeros with a single 1 per
        row, in the column given by that row's label.

    Raises:
        ValueError: If any label lies outside `[0, num_classes)`.
    """
    labels_dense = numpy.asarray(labels_dense)
    num_labels = labels_dense.shape[0]
    flat_labels = labels_dense.ravel()
    # An out-of-range label would otherwise land in a neighbouring row,
    # because the write below goes through a flattened index.
    if flat_labels.size and (flat_labels.min() < 0 or
                             flat_labels.max() >= num_classes):
        raise ValueError(
            'Labels must be in [0, %d), got range [%d, %d]' %
            (num_classes, flat_labels.min(), flat_labels.max()))
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + flat_labels] = 1
    return labels_one_hot
def extract_labels(filename, one_hot=False):
    """Load a gzipped MNIST label file as a 1D uint8 array [index].

    The file starts with two 32-bit header fields (magic number, item
    count) followed by one byte per label.

    Args:
        filename: Path of the gzipped label file.
        one_hot: When true, the labels are passed through
            `dense_to_one_hot` before being returned.

    Raises:
        ValueError: If the magic number is not 2049.
    """
    print('Extracting', filename)
    with gzip.open(filename) as stream:
        header_magic = _read32(stream)
        if header_magic != 2049:
            message = ('Invalid magic number %d in MNIST label file: %s' %
                       (header_magic, filename))
            raise ValueError(message)
        item_count = _read32(stream)
        raw_labels = stream.read(item_count)
        dense = numpy.frombuffer(raw_labels, dtype=numpy.uint8)
        return dense_to_one_hot(dense) if one_hot else dense
class DataSet(object):
    """Images and labels held in memory and served in mini-batches."""

    def __init__(self, images, labels, fake_data=False):
        """Store the data, flattening and rescaling real images.

        Args:
            images: Array shaped [num examples, rows, columns, depth] with
                depth 1. Ignored for shape checks when `fake_data` is true.
            labels: Array whose first dimension matches `images`.
            fake_data: When true, `images`/`labels` are stored untouched and
                the example count is fixed at 10000.
        """
        if fake_data:
            self._num_examples = 10000
        else:
            assert images.shape[0] == labels.shape[0], (
                "images.shape: %s labels.shape: %s" % (images.shape,
                                                       labels.shape))
            self._num_examples = images.shape[0]
            # Convert shape from [num examples, rows, columns, depth]
            # to [num examples, rows*columns] (assuming depth == 1)
            assert images.shape[3] == 1
            images = images.reshape(images.shape[0],
                                    images.shape[1] * images.shape[2])
            # Convert from [0, 255] -> [0.0, 1.0].
            images = images.astype(numpy.float32)
            images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        """The stored images (flattened float32 rows for real data)."""
        return self._images

    @property
    def labels(self):
        """The stored labels."""
        return self._labels

    @property
    def num_examples(self):
        """Number of examples in one epoch."""
        return self._num_examples

    @property
    def epochs_completed(self):
        """How many times `next_batch` has wrapped around the data."""
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set.

        Args:
            batch_size: Number of examples to return.
            fake_data: When true, return constant placeholder data (lists of
                784 ones and zero labels) without touching the epoch state.

        Returns:
            A pair `(images, labels)` for the batch.
        """
        if fake_data:
            # `range` (not `xrange`) so this path also runs on Python 3.
            fake_image = [1.0 for _ in range(784)]
            fake_label = 0
            return [fake_image for _ in range(batch_size)], [
                fake_label for _ in range(batch_size)]
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle the data
            perm = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]
def read_data_sets(train_dir, fake_data=False, one_hot=False,
                   validation_size=5000):
    """Download (if needed) and load MNIST as train/validation/test sets.

    Args:
        train_dir: Directory where the archives are stored or downloaded to.
        fake_data: When true, skip all file access and return three
            placeholder `DataSet` objects built with `fake_data=True`.
        one_hot: Forwarded to `extract_labels` for both label files.
        validation_size: Number of leading training examples split off as
            the validation set. Defaults to 5000.

    Returns:
        An object with `train`, `validation` and `test` attributes, each a
        `DataSet`.

    Raises:
        ValueError: If `validation_size` is negative or larger than the
            number of training examples.
    """
    class DataSets(object):
        pass
    data_sets = DataSets()
    if fake_data:
        data_sets.train = DataSet([], [], fake_data=True)
        data_sets.validation = DataSet([], [], fake_data=True)
        data_sets.test = DataSet([], [], fake_data=True)
        return data_sets
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    local_file = maybe_download(TRAIN_IMAGES, train_dir)
    train_images = extract_images(local_file)
    local_file = maybe_download(TRAIN_LABELS, train_dir)
    train_labels = extract_labels(local_file, one_hot=one_hot)
    local_file = maybe_download(TEST_IMAGES, train_dir)
    test_images = extract_images(local_file)
    local_file = maybe_download(TEST_LABELS, train_dir)
    test_labels = extract_labels(local_file, one_hot=one_hot)
    # Reject sizes that would silently produce a wrong or empty split.
    if not 0 <= validation_size <= len(train_images):
        raise ValueError(
            'Validation size should be between 0 and %d. Received: %d.' %
            (len(train_images), validation_size))
    # The first `validation_size` training examples become the validation
    # set; the remainder stays as training data.
    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]
    data_sets.train = DataSet(train_images, train_labels)
    data_sets.validation = DataSet(validation_images, validation_labels)
    data_sets.test = DataSet(test_images, test_labels)
    return data_sets
0 commit comments