11"""Learn to estimate functions from examples. (Chapters 18-20)"""
22
33from utils import (
4- removeall , unique , product , mode , argmax , argmax_random_tie , isclose ,
4+ removeall , unique , product , mode , argmax , argmax_random_tie , isclose , gaussian ,
55 dotproduct , vector_add , scalar_vector_product , weighted_sample_with_replacement ,
66 weighted_sampler , num_or_str , normalize , clip , sigmoid , print_table , DataFile
77)
1111import math
1212import random
1313
14- from statistics import mean
14+ from statistics import mean , stdev
1515from collections import defaultdict
1616
1717# ______________________________________________________________________________
@@ -178,6 +178,45 @@ def remove_examples(self, value=""):
178178 self .examples = [x for x in self .examples if value not in x ]
179179 self .update_values ()
180180
def split_values_by_classes(self):
    """Split the examples into buckets according to their class.

    Returns a defaultdict that maps each target value found in
    self.examples to a list of items. An item is one example with its
    target column removed; the remaining values keep their original order.
    """
    buckets = defaultdict(list)
    target = self.target

    for example in self.examples:
        # Resolve a negative target index against this example's length.
        position = target if target >= 0 else len(example) + target
        # Drop the target by position, not by value: a feature whose value
        # happens to equal a class label must stay in the item, otherwise
        # the remaining features would shift to the wrong index.
        item = [value for index, value in enumerate(example) if index != position]
        buckets[example[target]].append(item)  # Add item to bucket of its class

    return buckets
191+
def find_means_and_deviations(self):
    """Find the per-class means and standard deviations of the features.

    Returns a pair (means, deviations) of defaultdicts keyed by target value:
    means[t][i]      : mean of feature i over the items of class t.
    deviations[t][i] : sample standard deviation of feature i over the
                       items of class t.

    The items of each class come from self.split_values_by_classes().
    Raises statistics.StatisticsError when a class has fewer than two items,
    because stdev() needs at least two data points (and mean() at least one).
    """
    target_names = self.values[self.target]
    feature_numbers = len(self.inputs)

    item_buckets = self.split_values_by_classes()

    means = defaultdict(lambda: [0] * feature_numbers)
    deviations = defaultdict(lambda: [0] * feature_numbers)

    for t in target_names:
        items = item_buckets[t]
        # Gather the values of each feature across the class in one pass,
        # instead of rebuilding every column list once per item.
        columns = [[item[i] for item in items] for i in range(feature_numbers)]

        # Calculate means and deviations for the class
        for i, column in enumerate(columns):
            means[t][i] = mean(column)
            deviations[t][i] = stdev(column)

    return means, deviations
218+
219+
def __repr__(self):
    """Return a short summary: the name plus the example and attribute counts."""
    template = '<DataSet({}): {:d} examples, {:d} attributes>'
    name = self.name
    example_count = len(self.examples)
    attribute_count = len(self.attrs)
    return template.format(name, example_count, attribute_count)
@@ -267,15 +306,22 @@ def predict(example):
267306# ______________________________________________________________________________
268307
269308
270- def NaiveBayesLearner (dataset ):
def NaiveBayesLearner(dataset, continuous=True):
    """Build a naive Bayes learner for dataset.

    Delegates to NaiveBayesContinuous when continuous is truthy (the
    default) and to NaiveBayesDiscrete otherwise, and returns whatever the
    chosen learner returns.
    """
    learner = NaiveBayesContinuous if continuous else NaiveBayesDiscrete
    return learner(dataset)
314+
315+
316+ def NaiveBayesDiscrete (dataset ):
271317 """Just count how many times each value of each input attribute
272318 occurs, conditional on the target value. Count the different
273319 target values too."""
274320
275- targetvals = dataset .values [dataset .target ]
276- target_dist = CountingProbDist (targetvals )
321+ target_vals = dataset .values [dataset .target ]
322+ target_dist = CountingProbDist (target_vals )
277323 attr_dists = {(gv , attr ): CountingProbDist (dataset .values [attr ])
278- for gv in targetvals
324+ for gv in target_vals
279325 for attr in dataset .inputs }
280326 for example in dataset .examples :
281327 targetval = example [dataset .target ]
@@ -290,7 +336,29 @@ def class_probability(targetval):
290336 return (target_dist [targetval ] *
291337 product (attr_dists [targetval , attr ][example [attr ]]
292338 for attr in dataset .inputs ))
293- return argmax (targetvals , key = class_probability )
339+ return argmax (target_vals , key = class_probability )
340+
341+ return predict
342+
343+
def NaiveBayesContinuous(dataset):
    """Naive Bayes for numeric attributes, using one Gaussian per (class, attribute).

    Takes the per-class feature means and standard deviations from
    dataset.find_means_and_deviations() and returns a predict(example)
    function that picks the target value with the highest score.

    NOTE(review): target_dist is built from the list of target values only;
    no per-example counts are added in this function, so the class prior may
    not reflect class frequencies -- confirm against CountingProbDist.
    NOTE(review): means/deviations are indexed with the attribute itself,
    which presumably assumes the inputs are numbered 0..n-1 -- confirm.
    """
    means, deviations = dataset.find_means_and_deviations()

    target_vals = dataset.values[dataset.target]
    target_dist = CountingProbDist(target_vals)

    def score(label, example):
        """Prior of label times the Gaussian density of every input attribute."""
        result = target_dist[label]
        for attr in dataset.inputs:
            mu = means[label][attr]
            sigma = deviations[label][attr]
            result *= gaussian(mu, sigma, example[attr])
        return result

    def predict(example):
        """Return the target value whose score for example is the largest."""
        return argmax(target_vals, key=lambda label: score(label, example))

    return predict
296364
0 commit comments