|
| 1 | +#!/usr/bin/env python |
| 2 | +# Lab 10 MNIST and MLP with dropout |
| 3 | + |
| 4 | +import argparse |
| 5 | +import numpy as np |
| 6 | +import chainer |
| 7 | +import chainer.functions as F |
| 8 | +import chainer.links as L |
| 9 | + |
| 10 | +from chainer import training, Variable |
| 11 | +from chainer import datasets, iterators, optimizers |
| 12 | +from chainer import Chain |
| 13 | +from chainer.training import extensions |
| 14 | + |
| 15 | + |
class ModernMLP(chainer.Chain):
    """Three-layer perceptron with ReLU activations and dropout.

    Defined here so it can be wrapped by ``L.Classifier()``, which adds the
    loss and accuracy computation on top of the logits returned by this
    network.

    Args:
        n_units: Number of units in each of the two hidden layers.
        n_out: Number of output units (one logit per class).
    """

    def __init__(self, n_units, n_out):
        # ``None`` as the input size defers the weight shape until the first
        # forward pass, so the input dimension need not be given here.
        super(ModernMLP, self).__init__(
            l1=L.Linear(None, n_units),
            l2=L.Linear(None, n_units),
            l3=L.Linear(None, n_out)
        )

    def __call__(self, x):
        """Compute the class logits for a mini-batch ``x``.

        Each hidden layer is Linear -> ReLU -> dropout (ratio 0.3); the
        output layer is a plain Linear link with no activation.

        dropout:
            Drops input elements randomly with probability ``ratio`` and
            scales the remaining elements by factor ``1 / (1 - ratio)``.
            In testing mode, it does nothing and just returns ``x``.
            source: http://docs.chainer.org/en/latest/_modules/chainer/functions/noise/dropout.html?highlight=dropout

        Returns:
            The output of the final Linear layer (unnormalised scores).
        """
        # No explicit ``train`` flag is passed. Hard-coding ``train=True``
        # forced dropout on even during evaluation, and the keyword is
        # rejected by Chainer v2 and later, where dropout follows
        # ``chainer.config.train`` (switched off by the Evaluator extension).
        # On Chainer v1 the default is still ``train=True``, so behaviour
        # there is the same as before.
        h = F.dropout(F.relu(self.l1(x)), ratio=0.3)
        h = F.dropout(F.relu(self.l2(h)), ratio=0.3)
        return self.l3(h)
| 38 | + |
| 39 | + |
def main():
    """Train ``ModernMLP`` on MNIST using a Chainer ``Trainer``.

    Steps, in order: parse the command-line options, load MNIST, build the
    iterators, model and optimizer, register the evaluation / reporting /
    snapshot extensions, optionally restore a snapshot, and run training.

    Command-line options:
        --batch_size / -b: samples per mini-batch (default 128).
        --epoch / -e: number of passes over the training set (default 100).
        --gpu / -g: GPU ID, or -1 for CPU (default -1).
        --frequency / -f: snapshot interval in epochs; -1 means a single
            snapshot after the final epoch (default -1).
        --resume / -r: path of a snapshot to resume from (default '').
    """
    # Introduce argparse for clarity and organization.
    # Starting to use higher capacity models, thus set up for GPU.
    parser = argparse.ArgumentParser(description='Chainer-Tutorial: MLP')
    parser.add_argument('--batch_size', '-b', type=int, default=128,
                        help='Number of samples in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=100,
                        help='Number of times to train on data set')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID: -1 indicates CPU')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    args = parser.parse_args()

    # Load mnist data
    # http://docs.chainer.org/en/latest/reference/datasets.html
    train, test = chainer.datasets.get_mnist()

    # Define iterators. The test iterator makes exactly one ordered pass per
    # evaluation (no repeat, no shuffle).
    train_iter = chainer.iterators.SerialIterator(train, args.batch_size)
    test_iter = chainer.iterators.SerialIterator(test, args.batch_size,
                                                 repeat=False, shuffle=False)

    # Initialize model: Loss function defaults to softmax_cross_entropy.
    # 625 is n_units in each hidden layer and 10 is the output dimension;
    # the input dimension (784 for MNIST) is inferred on the first batch.
    model = L.Classifier(ModernMLP(625, 10))

    # Set up GPU usage if necessary. args.gpu is a condition as well as an
    # identification when passed to get_device().
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu()

    # Define optimizer (SGD, Adam, RMSprop, etc)
    # http://docs.chainer.org/en/latest/reference/optimizers.html
    # RMSprop default parameter setting:
    # lr=0.01, alpha=0.99, eps=1e-8
    optimizer = chainer.optimizers.RMSprop()
    optimizer.setup(model)

    # Set up trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'))

    # Evaluate the model at end of each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Helper functions (extensions) to monitor progress on stdout.
    report_params = [
        'epoch',
        'main/loss',
        'validation/main/loss',
        'main/accuracy',
        'validation/main/accuracy',
        'elapsed_time'
    ]
    # LogReport writes the log of evaluation statistics for each epoch.
    # It is registered exactly once: a second registration would only add a
    # redundant copy of the same extension.
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(report_params))
    trainer.extend(extensions.ProgressBar())

    # Here we add a bit more boilerplate code to output useful information
    # related to training. Very intuitive and great for post analysis.
    # source:
    # https://github.com/pfnet/chainer/blob/master/examples/mnist/train_mnist.py

    # Take a snapshot for each specified epoch. With the default of -1 the
    # interval equals the total epoch count, i.e. one snapshot at the end;
    # otherwise the interval is clamped to at least one epoch.
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Save two plot images to the result dir
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(
                ['main/loss', 'validation/main/loss'],
                'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    if args.resume:
        # Resume from a snapshot (NumPy NPZ format and HDF5 format available)
        # http://docs.chainer.org/en/latest/reference/serializers.html
        chainer.serializers.load_npz(args.resume, trainer)

    # Run trainer
    trainer.run()
| 138 | + |
| 139 | + |
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
| 142 | + |
| 143 | + |
| 144 | +""" |
| 145 | +Expected output with 1 gpu. |
| 146 | +
|
| 147 | +epoch main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time |
| 148 | +... |
| 149 | +90 0.217452 0.965264 0.958189 0.941456 294.61 |
| 150 | +91 0.196134 1.14531 0.959089 0.944917 297.859 |
| 151 | +92 0.203648 0.956059 0.957148 0.943928 301.109 |
| 152 | +93 0.20284 1.02199 0.960021 0.948378 304.362 |
| 153 | +94 0.195888 1.18072 0.958905 0.945609 307.619 |
| 154 | +95 0.199831 1.2245 0.958356 0.94195 310.879 |
| 155 | +96 0.200486 1.10434 0.960186 0.943038 314.151 |
| 156 | +97 0.202059 1.43919 0.960421 0.943335 317.447 |
| 157 | +98 0.221666 0.947955 0.959305 0.946994 320.745 |
| 158 | +99 0.200717 1.35896 0.961504 0.943137 324.038 |
| 159 | +100 0.182234 0.935365 0.962039 0.946301 327.323 |
| 160 | +""" |
0 commit comments