Skip to content

Commit 970e34e

Browse files
committed
Update cnn.py
1 parent 1ead15f commit 970e34e

1 file changed

Lines changed: 12 additions & 5 deletions

File tree

  • DeepLearning Tutorials/keras_usage

DeepLearning Tutorials/keras_usage/cnn.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cnn.py
66
CPU run command:
77
python cnn.py
8+
9+
2016.06.06更新:
10+
这份代码是keras开发初期写的,当时keras还没有现在这么流行,文档也还没那么丰富,所以我当时写了一些简单的教程。
11+
现在keras的API也发生了一些的变化,建议及推荐直接上keras.io看更加详细的教程。
12+
813
'''
914
#导入各种用到的模块组件
1015
from __future__ import absolute_import
@@ -19,6 +24,9 @@
1924
from six.moves import range
2025
from data import load_data
2126
import random
27+
import numpy as np
28+
29+
np.random.seed(1024) # for reproducibility
2230

2331

2432

@@ -46,7 +54,7 @@
4654
#border_mode可以是valid或者full,具体看这里说明:http://deeplearning.net/software/theano/library/tensor/nnet/conv.html#theano.tensor.nnet.conv.conv2d
4755
#激活函数用tanh
4856
#你还可以在model.add(Activation('tanh'))后加上dropout的技巧: model.add(Dropout(0.5))
49-
model.add(Convolution2D(4, 5, 5, border_mode='valid',input_shape=data.shape[-3:]))
57+
model.add(Convolution2D(4, 5, 5, border_mode='valid',input_shape=(1,28,28)))
5058
model.add(Activation('tanh'))
5159

5260

@@ -82,14 +90,14 @@
8290
##############
8391
#使用SGD + momentum
8492
#model.compile里的参数loss就是损失函数(目标函数)
85-
sgd = SGD(l2=0.0,lr=0.05, decay=1e-6, momentum=0.9, nesterov=True)
86-
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")
93+
sgd = SGD(lr=0.05, decay=1e-6, momentum=0.9, nesterov=True)
94+
model.compile(loss='categorical_crossentropy', optimizer=sgd)
8795

8896

8997
#调用fit方法,就是一个训练过程. 训练的epoch数设为10,batch_size为100.
9098
#数据经过随机打乱shuffle=True。verbose=1,训练过程中输出的信息,0、1、2三种方式都可以,无关紧要。show_accuracy=True,训练时每一个epoch都输出accuracy。
9199
#validation_split=0.2,将20%的数据作为验证集。
92-
model.fit(data, label, batch_size=100, nb_epoch=10,shuffle=True,verbose=1,show_accuracy=True,validation_split=0.2)
100+
model.fit(data, label, batch_size=100, nb_epoch=10,shuffle=True,verbose=1,validation_split=0.2)
93101

94102

95103
"""
@@ -123,4 +131,3 @@
123131
progbar.add(X_batch.shape[0], values=[("train loss", loss),("accuracy:", accuracy)] )
124132
125133
"""
126-

0 commit comments

Comments
 (0)