|
| 1 | +#!/usr/bin/env python |
| 2 | +"""Example: train a model on CIFAR10.""" |
| 3 | +from __future__ import division, print_function |
| 4 | + |
| 5 | +import argparse |
| 6 | +import functools |
| 7 | +import logging |
| 8 | +import os.path |
| 9 | + |
| 10 | +from caffe2.python import brew, core, data_parallel_model, optimizer, workspace |
| 11 | +from caffe2.python.core import DataType |
| 12 | +from caffe2.python.model_helper import ModelHelper |
| 13 | +from caffe2.python.modeling.initializers import Initializer, pFP16Initializer |
| 14 | + |
| 15 | + |
| 16 | +logging.basicConfig() |
| 17 | + |
| 18 | +TRAIN_ENTRIES = 50000 |
| 19 | +TEST_ENTRIES = 10000 |
| 20 | +BATCH_SIZE = 100 |
| 21 | +EPOCHS = 10 |
| 22 | +DISPLAY = 100 |
| 23 | +ACCURACY_MIN = 0.7 |
| 24 | +ACCURACY_MAX = 0.8 |
| 25 | + |
| 26 | + |
| 27 | +def AddInputOps(model, reader, batch_size, dtype): |
| 28 | + """Add input ops.""" |
| 29 | + data, label = brew.image_input( |
| 30 | + model, [reader], ['data', 'label'], |
| 31 | + batch_size=batch_size, use_caffe_datum=False, use_gpu_transform=True, |
| 32 | + scale=32, crop=32, mirror=1, color=True, mean=128.0, |
| 33 | + output_type='float16' if dtype == DataType.FLOAT16 else 'float', |
| 34 | + is_test=False) |
| 35 | + data = model.StopGradient(data, data) |
| 36 | + |
| 37 | + |
| 38 | +def AddForwardPassOps(model, loss_scale, dtype): |
| 39 | + """Add forward pass ops and return a list of losses.""" |
| 40 | + initializer = (pFP16Initializer if dtype == DataType.FLOAT16 |
| 41 | + else Initializer) |
| 42 | + with brew.arg_scope([brew.conv, brew.fc], |
| 43 | + WeightInitializer=initializer, |
| 44 | + BiasInitializer=initializer): |
| 45 | + conv1 = brew.conv(model, 'data', 'conv1', 3, 32, 5, pad=2, |
| 46 | + weight_init=('GaussianFill', |
| 47 | + {'std': 0.0001, 'mean': 0.0})) |
| 48 | + pool1 = brew.max_pool(model, conv1, 'pool1', kernel=3, stride=2) |
| 49 | + relu1 = brew.relu(model, pool1, 'relu1') |
| 50 | + conv2 = brew.conv(model, relu1, 'conv2', 32, 32, 5, pad=2, |
| 51 | + weight_init=('GaussianFill', {'std': 0.01})) |
| 52 | + conv2 = brew.relu(model, conv2, conv2) |
| 53 | + pool2 = brew.average_pool(model, conv2, 'pool2', kernel=3, stride=2) |
| 54 | + conv3 = brew.conv(model, pool2, 'conv3', 32, 64, 5, pad=2, |
| 55 | + weight_init=('GaussianFill', {'std': 0.01})) |
| 56 | + conv3 = brew.relu(model, conv3, conv3) |
| 57 | + pool3 = brew.average_pool(model, conv3, 'pool3', kernel=3, stride=2) |
| 58 | + fc1 = brew.fc(model, pool3, 'fc1', 64 * 3 * 3, 64, |
| 59 | + weight_init=('GaussianFill', {'std': 0.1})) |
| 60 | + fc2 = brew.fc(model, fc1, 'fc2', 64, 10, |
| 61 | + weight_init=('GaussianFill', {'std': 0.1})) |
| 62 | + |
| 63 | + if dtype == DataType.FLOAT16: |
| 64 | + fc2 = model.net.HalfToFloat(fc2, fc2 + '_fp32') |
| 65 | + softmax, loss = model.SoftmaxWithLoss([fc2, 'label'], ['softmax', 'loss']) |
| 66 | + loss = model.Scale(loss, loss, scale=loss_scale) |
| 67 | + brew.accuracy(model, [softmax, 'label'], 'accuracy') |
| 68 | + return [loss] |
| 69 | + |
| 70 | + |
| 71 | +def AddOptimizerOps(model): |
| 72 | + """Add optimizer ops.""" |
| 73 | + optimizer.add_weight_decay(model, 0.004) |
| 74 | + stepsize = TRAIN_ENTRIES * EPOCHS // BATCH_SIZE |
| 75 | + optimizer.build_sgd( |
| 76 | + model, 0.001, |
| 77 | + policy='step', stepsize=stepsize, gamma=0.1, |
| 78 | + momentum=0.9, nesterov=False) |
| 79 | + |
| 80 | + |
| 81 | +def AddPostSyncOps(model): |
| 82 | + """Add ops which run after the initial parameter sync.""" |
| 83 | + for param_info in model.GetOptimizationParamInfo(model.GetParams()): |
| 84 | + if param_info.blob_copy is not None: |
| 85 | + # Ensure copies are in sync after initial broadcast |
| 86 | + model.param_init_net.HalfToFloat( |
| 87 | + param_info.blob, |
| 88 | + param_info.blob_copy[core.DataType.FLOAT] |
| 89 | + ) |
| 90 | + |
| 91 | + |
| 92 | +def createTrainModel(lmdb_path, devices, dtype): |
| 93 | + """Create and return a training model, complete with training ops.""" |
| 94 | + model = ModelHelper(name='train', arg_scope={'order': 'NCHW'}) |
| 95 | + reader = model.CreateDB('train_reader', db=lmdb_path, db_type='lmdb') |
| 96 | + data_parallel_model.Parallelize_GPU( |
| 97 | + model, |
| 98 | + input_builder_fun=functools.partial( |
| 99 | + AddInputOps, reader=reader, |
| 100 | + batch_size=(BATCH_SIZE // len(devices)), dtype=dtype), |
| 101 | + forward_pass_builder_fun=functools.partial( |
| 102 | + AddForwardPassOps, dtype=dtype), |
| 103 | + optimizer_builder_fun=AddOptimizerOps, |
| 104 | + post_sync_builder_fun=AddPostSyncOps, |
| 105 | + devices=devices, use_nccl=True) |
| 106 | + workspace.RunNetOnce(model.param_init_net) |
| 107 | + workspace.CreateNet(model.net) |
| 108 | + return model |
| 109 | + |
| 110 | + |
| 111 | +def createTestModel(lmdb_path, devices, dtype): |
| 112 | + """Create and return a test model. Does not include training ops.""" |
| 113 | + model = ModelHelper(name='test', arg_scope={'order': 'NCHW'}, |
| 114 | + init_params=False) |
| 115 | + reader = model.CreateDB('test_reader', db=lmdb_path, db_type='lmdb') |
| 116 | + data_parallel_model.Parallelize_GPU( |
| 117 | + model, |
| 118 | + input_builder_fun=functools.partial( |
| 119 | + AddInputOps, reader=reader, |
| 120 | + batch_size=(BATCH_SIZE // len(devices)), dtype=dtype), |
| 121 | + forward_pass_builder_fun=functools.partial( |
| 122 | + AddForwardPassOps, dtype=dtype), |
| 123 | + param_update_builder_fun=None, |
| 124 | + devices=devices) |
| 125 | + workspace.RunNetOnce(model.param_init_net) |
| 126 | + workspace.CreateNet(model.net) |
| 127 | + return model |
| 128 | + |
| 129 | + |
| 130 | +def getArgs(): |
| 131 | + """Return command-line arguments.""" |
| 132 | + CURDIR = os.path.dirname(__file__) |
| 133 | + parser = argparse.ArgumentParser( |
| 134 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
| 135 | + parser.add_argument('--train-lmdb', help='Path to training LMDB', |
| 136 | + default=os.path.join(CURDIR, 'cifar10_train_lmdb')) |
| 137 | + parser.add_argument('--test-lmdb', help='Path to test LMDB', |
| 138 | + default=os.path.join(CURDIR, 'cifar10_test_lmdb')) |
| 139 | + parser.add_argument('--dtype', choices=['float', 'float16'], |
| 140 | + default='float', help='Data type used for training') |
| 141 | + parser.add_argument('--gpus', |
| 142 | + help='Comma separated list of GPU devices to use') |
| 143 | + parser.add_argument('--num_gpus', type=int, default=1, |
| 144 | + help='Number of GPU devices (instead of --gpus)') |
| 145 | + parser.add_argument('--all-gpus', action='store_true', |
| 146 | + help='Use all GPUs in the system') |
| 147 | + args = parser.parse_args() |
| 148 | + |
| 149 | + args.dtype = (DataType.FLOAT16 if args.dtype == 'float16' |
| 150 | + else DataType.FLOAT) |
| 151 | + |
| 152 | + if args.all_gpus: |
| 153 | + args.num_gpus = workspace.NumCudaDevices() |
| 154 | + args.gpus = range(args.num_gpus) |
| 155 | + else: |
| 156 | + if args.gpus is not None: |
| 157 | + args.gpus = [int(x) for x in args.gpus.split(',')] |
| 158 | + args.num_gpus = len(args.gpus) |
| 159 | + else: |
| 160 | + args.gpus = range(args.num_gpus) |
| 161 | + args.num_gpus = args.num_gpus |
| 162 | + return args |
| 163 | + |
| 164 | + |
| 165 | +def main(args): |
| 166 | + """Train and test.""" |
| 167 | + train_model = createTrainModel(args.train_lmdb, args.gpus, args.dtype) |
| 168 | + test_model = createTestModel(args.test_lmdb, args.gpus, args.dtype) |
| 169 | + |
| 170 | + train_iter_per_epoch = TRAIN_ENTRIES // BATCH_SIZE |
| 171 | + test_iter_per_epoch = TEST_ENTRIES // BATCH_SIZE |
| 172 | + scope_prefix = 'gpu_%d/' % args.gpus[0] |
| 173 | + |
| 174 | + for epoch in range(1, EPOCHS + 1): |
| 175 | + # Train |
| 176 | + for iteration in range(1, train_iter_per_epoch + 1): |
| 177 | + workspace.RunNet(train_model.net.Proto().name) |
| 178 | + if not iteration % DISPLAY: |
| 179 | + loss = workspace.FetchBlob(scope_prefix + 'loss') |
| 180 | + print("Epoch %d/%d, iteration %4d/%d, loss=%f" % ( |
| 181 | + epoch, EPOCHS, iteration, train_iter_per_epoch, loss)) |
| 182 | + |
| 183 | + # Test |
| 184 | + losses = [] |
| 185 | + accuracies = [] |
| 186 | + for _ in range(test_iter_per_epoch): |
| 187 | + workspace.RunNet(test_model.net.Proto().name) |
| 188 | + # Take average values across all GPUs |
| 189 | + losses.append(sum( |
| 190 | + workspace.FetchBlob('gpu_%d/loss' % g) for g in args.gpus |
| 191 | + ) / len(args.gpus)) |
| 192 | + accuracies.append(sum( |
| 193 | + workspace.FetchBlob('gpu_%d/accuracy' % g) for g in args.gpus |
| 194 | + ) / len(args.gpus)) |
| 195 | + |
| 196 | + loss = sum(losses) / len(losses) |
| 197 | + accuracy = sum(accuracies) / len(accuracies) |
| 198 | + print("Test loss: %f, accuracy: %f" % (loss, accuracy)) |
| 199 | + |
| 200 | + if accuracy < ACCURACY_MIN or accuracy > ACCURACY_MAX: |
| 201 | + raise RuntimeError( |
| 202 | + "Final accuracy %f is not in the expected range [%f, %f]" % |
| 203 | + (accuracy, ACCURACY_MIN, ACCURACY_MAX)) |
| 204 | + |
| 205 | + |
| 206 | +if __name__ == '__main__': |
| 207 | + core.GlobalInit(['caffe2', '--caffe2_log_level=0']) |
| 208 | + main(getArgs()) |
0 commit comments