Skip to content

Commit 104c4de

Browse files
committed
added testing as requested by issue aymericdamien#2
1 parent 21e2454 commit 104c4de

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

examples/2 - Basic Classifiers/linear_regression.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,23 @@
5555
"W=", sess.run(W), "b=", sess.run(b)
5656

5757
print "Optimization Finished!"
58-
print "cost=", sess.run(cost, feed_dict={X: train_X, Y: train_Y}), "W=", sess.run(W), "b=", sess.run(b)
58+
training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
59+
print "Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n'
60+
61+
62+
# Testing example, as requested (Issue #2)
63+
test_X = numpy.asarray([6.83,4.668,8.9,7.91,5.7,8.7,3.1,2.1])
64+
test_Y = numpy.asarray([1.84,2.273,3.2,2.831,2.92,3.24,1.35,1.03])
65+
66+
print "Testing... (L2 loss Comparison)"
67+
testing_cost = sess.run(tf.reduce_sum(tf.pow(activation-Y, 2))/(2*test_X.shape[0]),
68+
feed_dict={X: test_X, Y: test_Y}) #same function as cost above
69+
print "Testing cost=", testing_cost
70+
print "Absolute l2 loss difference:", abs(training_cost - testing_cost)
5971

6072
#Graphic display
6173
plt.plot(train_X, train_Y, 'ro', label='Original data')
74+
plt.plot(test_X, test_Y, 'bo', label='Testing data')
6275
plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
6376
plt.legend()
6477
plt.show()

0 commit comments

Comments
 (0)