Skip to content

Commit fdcb256

Browse files
committed
Merge branch 'master' of github.com:hunkim/DeepLearningZeroToAll
2 parents 32bc4f3 + 7930416 commit fdcb256

1 file changed

Lines changed: 160 additions & 0 deletions

File tree

Chainer/chlab-10-3-mnist_modern.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#!/usr/bin/env python
2+
# Lab 10 MNIST and MLP with dropout
3+
4+
import argparse
5+
import numpy as np
6+
import chainer
7+
import chainer.functions as F
8+
import chainer.links as L
9+
10+
from chainer import training, Variable
11+
from chainer import datasets, iterators, optimizers
12+
from chainer import Chain
13+
from chainer.training import extensions
14+
15+
16+
class ModernMLP(chainer.Chain):
    """Multi-layer perceptron with ReLU activations and dropout.

    Meant to be wrapped by ``L.Classifier``, which supplies the loss and
    accuracy computation on top of the raw scores returned by this chain.

    Args:
        n_units: Number of units in each of the two hidden layers.
        n_out: Number of output units.
    """

    def __init__(self, n_units, n_out):
        # Passing ``None`` as the input size defers it to the first forward
        # pass (see the Chainer ``L.Linear`` documentation).
        super(ModernMLP, self).__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(None, n_units),
            l3=L.Linear(None, n_out)
        )

    def __call__(self, x):
        """Return the output-layer scores for the mini-batch ``x``.

        Each hidden layer is ``Linear -> ReLU -> dropout``.

        dropout:
        This function drops input elements randomly with probability
        ``ratio`` and scales the remaining elements by factor
        ``1 / (1 - ratio)``. In testing mode, it does nothing and
        just returns ``x``.
        source: http://docs.chainer.org/en/latest/_modules/chainer/functions/noise/dropout.html?highlight=dropout
        """
        # Do NOT pass ``train=True`` here: hard-coding it keeps dropout
        # active during evaluation, and Chainer v2+ rejects the keyword
        # altogether. Without it, dropout follows Chainer's global train
        # mode (``chainer.config.train`` in v2+), which the ``Evaluator``
        # extension switches off. On Chainer v1 the omitted argument
        # defaults to ``True``, i.e. the previous behaviour.
        h = F.dropout(F.relu(self.l1(x)), ratio=0.3)
        h = F.dropout(F.relu(self.l2(h)), ratio=0.3)
        return self.l3(h)
38+
39+
40+
def main():
    """Train ``ModernMLP`` on MNIST using Chainer's ``Trainer``.

    Parses the command-line options, builds the data iterators, model,
    optimizer and trainer, registers the reporting and snapshot extensions,
    optionally resumes from a snapshot, and finally runs the training loop.
    """
    # Introduce argparse for clarity and organization.
    # Starting to use higher capacity models, thus set up for GPU.
    parser = argparse.ArgumentParser(description='Chainer-Tutorial: MLP')
    parser.add_argument('--batch_size', '-b', type=int, default=128,
                        help='Number of samples in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=100,
                        help='Number of times to train on data set')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID: -1 indicates CPU')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    args = parser.parse_args()

    # Load mnist data
    # http://docs.chainer.org/en/latest/reference/datasets.html
    train, test = chainer.datasets.get_mnist()

    # Define iterators. The test iterator makes a single, unshuffled pass.
    train_iter = chainer.iterators.SerialIterator(train, args.batch_size)
    test_iter = chainer.iterators.SerialIterator(test, args.batch_size,
                                                 repeat=False, shuffle=False)

    # Initialize model: Loss function defaults to softmax_cross_entropy.
    # 784 is dimension of the inputs, 625 is n_units in hidden layer
    # and 10 is the output dimension.
    model = L.Classifier(ModernMLP(625, 10))

    # Set up GPU usage if necessary. args.gpu is a condition as well as an
    # identification when passed to get_device().
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu()

    # Define optimizer (SGD, Adam, RMSprop, etc)
    # http://docs.chainer.org/en/latest/reference/optimizers.html
    # RMSprop default parameter setting:
    # lr=0.01, alpha=0.99, eps=1e-8
    optimizer = chainer.optimizers.RMSprop()
    optimizer.setup(model)

    # Set up trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'))

    # Evaluate the model at end of each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Helper functions (extensions) to monitor progress on stdout.
    report_params = [
        'epoch',
        'main/loss',
        'validation/main/loss',
        'main/accuracy',
        'validation/main/accuracy',
        'elapsed_time'
    ]
    # Write a log of evaluation statistics for each epoch. Register this
    # exactly once; a second LogReport would only repeat the same logging.
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(report_params))
    trainer.extend(extensions.ProgressBar())

    # Here we add a bit more boilerplate code to help output useful
    # information related to training. Very intuitive and great for
    # post-analysis.
    # source:
    # https://github.com/pfnet/chainer/blob/master/examples/mnist/train_mnist.py

    # Take a snapshot for each specified epoch. With the default of -1 a
    # single snapshot is taken after the final epoch.
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Save two plot images to the result dir
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(
                ['main/loss', 'validation/main/loss'],
                'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    if args.resume:
        # Resume from a snapshot (NumPy NPZ format and HDF5 format available)
        # http://docs.chainer.org/en/latest/reference/serializers.html
        chainer.serializers.load_npz(args.resume, trainer)

    # Run trainer
    trainer.run()
138+
139+
140+
# Script entry point: start training only when this file is run directly.
if __name__ == '__main__':
    main()
142+
143+
144+
"""
Expected output with 1 GPU.

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
...
90          0.217452    0.965264              0.958189       0.941456                  294.61
91          0.196134    1.14531               0.959089       0.944917                  297.859
92          0.203648    0.956059              0.957148       0.943928                  301.109
93          0.20284     1.02199               0.960021       0.948378                  304.362
94          0.195888    1.18072               0.958905       0.945609                  307.619
95          0.199831    1.2245                0.958356       0.94195                   310.879
96          0.200486    1.10434               0.960186       0.943038                  314.151
97          0.202059    1.43919               0.960421       0.943335                  317.447
98          0.221666    0.947955              0.959305       0.946994                  320.745
99          0.200717    1.35896               0.961504       0.943137                  324.038
100         0.182234    0.935365              0.962039       0.946301                  327.323
"""

0 commit comments

Comments
 (0)