|
5 | 5 | THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cnn.py |
6 | 6 | CPU run command: |
7 | 7 | python cnn.py |
| 8 | +
|
| 9 | +2016.06.06更新: |
| 10 | +这份代码是keras开发初期写的,当时keras还没有现在这么流行,文档也还没那么丰富,所以我当时写了一些简单的教程。 |
| 11 | +现在keras的API也发生了一些的变化,建议及推荐直接上keras.io看更加详细的教程。 |
| 12 | +
|
8 | 13 | ''' |
9 | 14 | #导入各种用到的模块组件 |
10 | 15 | from __future__ import absolute_import |
|
19 | 24 | from six.moves import range |
20 | 25 | from data import load_data |
21 | 26 | import random |
| 27 | +import numpy as np |
| 28 | + |
| 29 | +np.random.seed(1024) # for reproducibility |
22 | 30 |
|
23 | 31 |
|
24 | 32 |
|
|
46 | 54 | #border_mode可以是valid或者full,具体看这里说明:http://deeplearning.net/software/theano/library/tensor/nnet/conv.html#theano.tensor.nnet.conv.conv2d |
47 | 55 | #激活函数用tanh |
48 | 56 | #你还可以在model.add(Activation('tanh'))后加上dropout的技巧: model.add(Dropout(0.5)) |
49 | | -model.add(Convolution2D(4, 5, 5, border_mode='valid',input_shape=data.shape[-3:])) |
| 57 | +model.add(Convolution2D(4, 5, 5, border_mode='valid',input_shape=(1,28,28))) |
50 | 58 | model.add(Activation('tanh')) |
51 | 59 |
|
52 | 60 |
|
|
82 | 90 | ############## |
83 | 91 | #使用SGD + momentum |
84 | 92 | #model.compile里的参数loss就是损失函数(目标函数) |
85 | | -sgd = SGD(l2=0.0,lr=0.05, decay=1e-6, momentum=0.9, nesterov=True) |
86 | | -model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical") |
| 93 | +sgd = SGD(lr=0.05, decay=1e-6, momentum=0.9, nesterov=True) |
| 94 | +model.compile(loss='categorical_crossentropy', optimizer=sgd) |
87 | 95 |
|
88 | 96 |
|
89 | 97 | #调用fit方法,就是一个训练过程. 训练的epoch数设为10,batch_size为100. |
90 | 98 | #数据经过随机打乱shuffle=True。verbose=1,训练过程中输出的信息,0、1、2三种方式都可以,无关紧要。show_accuracy=True,训练时每一个epoch都输出accuracy。 |
91 | 99 | #validation_split=0.2,将20%的数据作为验证集。 |
92 | | -model.fit(data, label, batch_size=100, nb_epoch=10,shuffle=True,verbose=1,show_accuracy=True,validation_split=0.2) |
| 100 | +model.fit(data, label, batch_size=100, nb_epoch=10,shuffle=True,verbose=1,validation_split=0.2) |
93 | 101 |
|
94 | 102 |
|
95 | 103 | """ |
|
123 | 131 | progbar.add(X_batch.shape[0], values=[("train loss", loss),("accuracy:", accuracy)] ) |
124 | 132 |
|
125 | 133 | """ |
126 | | - |
|
0 commit comments