Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions code/convolutional_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def __init__(self, rng, input, filter_shape, image_shape, poolsize=(2, 2)):
# store parameters of this layer
self.params = [self.W, self.b]

# keep track of model input
self.input = input


def evaluate_lenet5(learning_rate=0.1, n_epochs=200,
dataset='mnist.pkl.gz',
Expand Down
3 changes: 3 additions & 0 deletions code/logistic_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def __init__(self, input, n_in, n_out):
# symbolic form
self.y_pred = T.argmax(self.p_y_given_x, axis=1)

# keep track of model input
self.input = input

def negative_log_likelihood(self, y):
"""Return the negative log-likelihood of the prediction of this model
under a given target distribution.
Expand Down
33 changes: 33 additions & 0 deletions code/logistic_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def __init__(self, input, n_in, n_out):
# parameters of the model
self.params = [self.W, self.b]

# keep track of model input
self.input = input

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check if this is needed for the other models? That way, all pickled files will have the needed information.

def negative_log_likelihood(self, y):
"""Return the mean of the negative log-likelihood of the prediction
of this model under a given target distribution.
Expand Down Expand Up @@ -415,6 +418,10 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
)
)

# save the best model
with open('best_model.pkl', 'w') as f:
cPickle.dump(classifier, f)

if patience <= iter:
done_looping = True
break
Expand All @@ -433,5 +440,31 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
os.path.split(__file__)[1] +
' ran for %.1fs' % ((end_time - start_time)))


def predict():
"""
An example of how to load a trained model and use it
to predict labels.
"""

# load the saved model
classifier = cPickle.load(open('best_model.pkl'))

# compile a predictor function
predict_model = theano.function(
inputs=[classifier.input],
outputs=classifier.y_pred)

# We can test it on some examples from test test
dataset='mnist.pkl.gz'
datasets = load_data(dataset)
test_set_x, test_set_y = datasets[2]
test_set_x = test_set_x.get_value()

predicted_values = predict_model(test_set_x[:10])
print ("Predicted values for the first 10 examples in test set:")
print predicted_values


if __name__ == '__main__':
sgd_optimization_mnist()
3 changes: 3 additions & 0 deletions code/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def __init__(self, rng, input, n_in, n_hidden, n_out):
self.params = self.hiddenLayer.params + self.logRegressionLayer.params
# end-snippet-3

# keep track of model input
self.input = input


def test_mlp(learning_rate=0.01, L1_reg=0.00, L2_reg=0.0001, n_epochs=1000,
dataset='mnist.pkl.gz', batch_size=20, n_hidden=500):
Expand Down
13 changes: 13 additions & 0 deletions doc/logreg.txt
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,19 @@ approximately 1.936 epochs/sec and it took 75 epochs to reach a test
error of 7.489%. On the GPU the code does almost 10.0 epochs/sec. For this
instance we used a batch size of 600.


Prediction Using a Trained Model
+++++++++++++++++++++++++++++++

``sgd_optimization_mnist`` serialize and pickle the model each time new
lowest validation error is reached. We can reload this model and predict
labels of new data. ``predict`` function shows an example of how
this could be done.

.. literalinclude:: ../code/logistic_sgd.py
:pyobject: predict


.. rubric:: Footnotes

.. [#f1] For smaller datasets and simpler models, more sophisticated descent
Expand Down