@@ -1049,14 +1049,35 @@ def grade_learner(predict, tests):
10491049 return mean (int (predict (X ) == y ) for X , y in tests )
10501050
10511051
1052- def train_test_split (dataset , start , end ):
1053- """Reserve dataset.examples[start:end] for test; train on the remainder."""
1054- start = int (start )
1055- end = int (end )
1056- examples = dataset .examples
1057- train = examples [:start ] + examples [end :]
1058- val = examples [start :end ]
1059- return train , val
1052+ def train_test_split (dataset , start = None , end = None , test_split = None ):
1053+ """If you are giving 'start' and 'end' as a parameter,
1054+ then it will return testing set from index 'start' to 'end'
1055+ and rest for training.
1056+ If you give 'test_split' as parameter then it will first shuffle the
1057+ dataset then return test_split * 100% as testing set and rest as
1058+ training set.
1059+ """
1060+
1061+ if start == None and end != None :
1062+ raise ValueError ("'start' parameter is missing" )
1063+
1064+ if start != None and end == None :
1065+ raise ValueError ("'end' parameter is missing" )
1066+
1067+ if test_split == None :
1068+ examples = dataset .examples
1069+ train = examples [:start ] + examples [end :]
1070+ val = examples [start :end ]
1071+ return train , val
1072+ else :
1073+ examples = dataset .examples
1074+ total_size = len (examples )
1075+ val_size = int (total_size * test_split )
1076+ train_size = total_size - val_size
1077+ random .shuffle (examples )
1078+ train = examples [:train_size ]
1079+ val = examples [train_size :total_size ]
1080+ return train , val
10601081
10611082
10621083def cross_validation (learner , size , dataset , k = 10 , trials = 1 ):
0 commit comments