@@ -109,6 +109,9 @@ def __init__(self, input, n_in, n_out):
109109 # parameters of the model
110110 self .params = [self .W , self .b ]
111111
112+ # keep track of model input
113+ self .input = input
114+
112115 def negative_log_likelihood (self , y ):
113116 """Return the mean of the negative log-likelihood of the prediction
114117 of this model under a given target distribution.
@@ -415,6 +418,10 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
415418 )
416419 )
417420
421+ # save the best model
422+ with open ('best_model.pkl' , 'w' ) as f :
423+ cPickle .dump (classifier , f )
424+
418425 if patience <= iter :
419426 done_looping = True
420427 break
@@ -433,5 +440,31 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
433440 os .path .split (__file__ )[1 ] +
434441 ' ran for %.1fs' % ((end_time - start_time )))
435442
443+
444+ def predict ():
445+ """
446+ An example of how to load a trained model and use it
447+ to predict labels.
448+ """
449+
450+ # load the saved model
451+ classifier = cPickle .load (open ('best_model.pkl' ))
452+
453+ # compile a predictor function
454+ predict_model = theano .function (
455+ inputs = [classifier .input ],
456+ outputs = classifier .y_pred )
457+
458+ # We can test it on some examples from test test
459+ dataset = 'mnist.pkl.gz'
460+ datasets = load_data (dataset )
461+ test_set_x , test_set_y = datasets [2 ]
462+ test_set_x = test_set_x .get_value ()
463+
464+ predicted_values = predict_model (test_set_x [:10 ])
465+ print ("Predicted values for the first 10 examples in test set:" )
466+ print predicted_values
467+
468+
436469if __name__ == '__main__' :
437470 sgd_optimization_mnist ()
0 commit comments