Skip to content

Commit 97ea80a

Browse files
committed
Adding confidence bounds to local linear regression
1 parent 904ff3f commit 97ea80a

8 files changed

Lines changed: 640 additions & 10 deletions

bigml/linear.py

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@
4545
import copy
4646
import json
4747

48+
try:
49+
import numpy as np
50+
from scipy.stats import t as student_t
51+
except ImportError:
52+
raise ImportError("Failed to import the numpy and scipy modules needed"
53+
" for this class.")
54+
4855
from functools import cmp_to_key
4956

5057
from bigml.api import FINISHED
@@ -67,6 +74,7 @@
6774
"items": "items"}
6875

6976
CATEGORICAL = "categorical"
77+
CONFIDENCE = 0.95
7078

7179
DUMMY = "dummy"
7280
CONTRAST = "contrast"
@@ -109,9 +117,12 @@ def __init__(self, linear_regression, api=None):
109117
self.coefficients = []
110118
self.data_field_types = {}
111119
self.field_codings = {}
112-
self.numeric_fields = {}
113120
self.bias = None
114-
121+
self.xtx = []
122+
self.inv_xtx = None
123+
self.mean_squared_error = None
124+
self.number_of_parameters = None
125+
self.number_of_samples = None
115126

116127
self.resource_id, linear_regression = get_resource_dict( \
117128
linear_regression, "linearregression", api=api)
@@ -143,6 +154,7 @@ def __init__(self, linear_regression, api=None):
143154
field_id for field_id, _ in
144155
sorted(self.fields.items(),
145156
key=lambda x: x[1].get("column_number"))]
157+
self.coeff_ids = self.input_fields[:]
146158
self.coefficients = linear_regression_info.get( \
147159
'coefficients', [])
148160
self.bias = linear_regression_info.get('bias', True)
@@ -155,16 +167,26 @@ def __init__(self, linear_regression, api=None):
155167
numerics=True)
156168
self.field_codings = linear_regression_info.get( \
157169
'field_codings', {})
158-
print "**before", self.field_codings
159170
self.format_field_codings()
160-
print "**after", self.field_codings
161171
for field_id in self.field_codings:
162172
if field_id not in fields and \
163173
field_id in self.inverted_fields:
164174
self.field_codings.update( \
165175
{self.inverted_fields[field_id]: \
166176
self.field_codings[field_id]})
167177
del self.field_codings[field_id]
178+
stats = linear_regression_info["stats"]
179+
if stats is not None and "xtx" in stats:
180+
self.xtx = stats["xtx"][:]
181+
self.mean_squared_error = stats["mean_squared_error"]
182+
self.number_of_parameters = stats["number_of_parameters"]
183+
self.number_of_samples = stats["number_of_samples"]
184+
# to be used in predictions
185+
self.t_crit = student_t.interval( \
186+
CONFIDENCE,
187+
self.number_of_samples - self.number_of_parameters)[1]
188+
self.inv_xtx = list(np.linalg.inv(np.array(self.xtx)))
189+
168190
else:
169191
raise Exception("The linear regression isn't finished yet")
170192
else:
@@ -173,7 +195,7 @@ def __init__(self, linear_regression, api=None):
173195
" in the resource:\n\n%s" %
174196
linear_regression)
175197

176-
def expand_input(self, input_data, unique_terms):
198+
def expand_input(self, input_data, unique_terms, compact=False):
177199
""" Creates an input array with the values in input_data and
178200
unique_terms and the following rules:
179201
- fields are ordered as input_fields
@@ -187,7 +209,7 @@ def expand_input(self, input_data, unique_terms):
187209
as numerics.
188210
"""
189211
input_array = []
190-
for index, field_id in enumerate(self.input_fields):
212+
for index, field_id in enumerate(self.coeff_ids):
191213
field = self.fields[field_id]
192214
optype = field["optype"]
193215
missing = False
@@ -216,7 +238,7 @@ def expand_input(self, input_data, unique_terms):
216238

217239
if optype == CATEGORICAL:
218240
new_inputs = self.categorical_encoding( \
219-
new_inputs, field_id)
241+
new_inputs, field_id, compact)
220242

221243
input_array.extend(new_inputs)
222244

@@ -225,7 +247,7 @@ def expand_input(self, input_data, unique_terms):
225247

226248
return input_array
227249

228-
def categorical_encoding(self, inputs, field_id):
250+
def categorical_encoding(self, inputs, field_id, compact):
229251
"""Returns the prediction and the confidence intervals
230252
231253
input_data: Input data to be predicted
@@ -235,10 +257,17 @@ def categorical_encoding(self, inputs, field_id):
235257

236258
projections = self.field_codings[field_id].get( \
237259
CONTRAST, self.field_codings[field_id].get(OTHER))
238-
print "***", projections, new_inputs
239260
if projections is not None:
240261
new_inputs = flatten(dot(projections, [new_inputs]))
241262

263+
if compact and self.field_codings[field_id].get(DUMMY) is not None:
264+
dummy_class = self.field_codings[field_id][DUMMY]
265+
index = self.categories[field_id].index(dummy_class)
266+
cat_new_inputs = new_inputs[0: index]
267+
if len(new_inputs) > (index + 1):
268+
cat_new_inputs.extend(new_inputs[index + 1 :])
269+
new_inputs = cat_new_inputs
270+
242271
return new_inputs
243272

244273
def predict(self, input_data, full=False):
@@ -275,19 +304,46 @@ def predict(self, input_data, full=False):
275304

276305
# Creates an input vector with the values for all expanded fields.
277306
input_array = self.expand_input(new_data, unique_terms)
307+
compact_input_array = self.expand_input(new_data, unique_terms, True)
278308

279309
prediction = dot([flatten(self.coefficients)], [input_array])[0][0]
280310

281311
result = {
282312
"prediction": prediction}
313+
if self.inv_xtx is not None:
314+
result.update({"confidence_bounds": self.confidence_bounds( \
315+
compact_input_array)})
283316

284317
if full:
285-
result.update({'unused_fields': unused_fields})
318+
result.update({"unused_fields": unused_fields})
286319
else:
287320
result = result["prediction"]
288321

289322
return result
290323

324+
325+
def confidence_bounds(self, input_array):
    """Computes the confidence and prediction bounds for a prediction

    input_array: expanded input vector. It is multiplied on both sides
                 of `self.inv_xtx` using the `dot` helper.

    Returns a dict with the keys "confidence_interval" and
    "prediction_interval". Both values are 0 when the stored mean
    squared error is 0 or when the bounds cannot be computed
    numerically (e.g. missing error information or a negative value
    under the square root).
    """
    product = dot(dot([input_array], self.inv_xtx),
                  [input_array])[0][0]

    confidence_interval, prediction_interval = (0, 0)
    if self.mean_squared_error != 0:
        try:
            # Both values are computed before any assignment so that a
            # failure in either one leaves the pair at (0, 0).
            bounds = (
                self.t_crit * math.sqrt(
                    self.mean_squared_error * product),
                self.t_crit * math.sqrt(
                    self.mean_squared_error * (product + 1)))
        except (TypeError, ValueError):
            # TypeError: non-numeric operands (e.g. a None error value).
            # ValueError: math.sqrt received a negative number.
            # Any other exception is a programming error and propagates.
            pass
        else:
            confidence_interval, prediction_interval = bounds

    return {"confidence_interval": confidence_interval,
            "prediction_interval": prediction_interval}
345+
346+
291347
def format_field_codings(self):
292348
""" Changes the field codings format to the dict notation
293349

bigml/modelfields.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ def add_terms(self, categories=False, numerics=False):
198198
if categories and field['optype'] == 'categorical':
199199
self.categories[field_id] = [category for \
200200
[category, _] in field['summary']['categories']]
201+
# Datetime fields are filtered out of `coeff_ids` when the object
# defines that attribute. hasattr() needs the attribute name as a
# string, and the filtered list has to be stored back in `coeff_ids`
# (not in a new `coeff_id` attribute) for the removal to take effect.
if field['optype'] == 'datetime' and \
        hasattr(self, "coeff_ids"):
    self.coeff_ids = [coeff_id for coeff_id in self.coeff_ids
                      if coeff_id != field_id]
201205
if numerics and hasattr(self, "missing_numerics") and \
202206
self.missing_numerics and field['optype'] == 'numeric' \
203207
and hasattr(self, "numeric_fields"):

bigml/tests/create_linear_steps.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# -*- coding: utf-8 -*-
2+
#!/usr/bin/env python
3+
#
4+
# Copyright 2019 BigML
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
7+
# not use this file except in compliance with the License. You may obtain
8+
# a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
# License for the specific language governing permissions and limitations
16+
# under the License.
17+
18+
import time
19+
import json
20+
import os
21+
from datetime import datetime, timedelta
22+
from world import world
23+
from nose.tools import eq_, assert_less
24+
25+
from bigml.api import HTTP_CREATED
26+
from bigml.api import HTTP_ACCEPTED
27+
from bigml.api import FINISHED
28+
from bigml.api import FAULTY
29+
from bigml.api import get_status
30+
31+
from read_linear_steps import i_get_the_linear_regression
32+
33+
34+
#@step(r'the linear name is "(.*)"')
35+
def i_check_linear_name(step, name):
    """Checks that the stored linear regression has the expected name"""
    eq_(name, world.linear_regression['name'])
38+
39+
#@step(r'I create a Linear Regression from a dataset$')
40+
def i_create_a_linear_regression_from_dataset(step):
    """Creates a linear regression from the dataset stored in `world`

    The response code, location and object are kept in `world`, and the
    new resource id is appended to `world.linear_regressions`.
    """
    dataset_id = world.dataset.get('resource')
    response = world.api.create_linear_regression(
        dataset_id, {'name': 'new linear regression'})
    world.status = response['code']
    eq_(world.status, HTTP_CREATED)
    world.location = response['location']
    world.linear_regression = response['object']
    world.linear_regressions.append(response['resource'])
49+
50+
51+
#@step(r'I create a Linear Regression from a dataset$')
52+
def i_create_a_linear_regression_with_params(step, params):
    """Creates a linear regression using `params` and no objective"""
    i_create_a_linear_regression_with_objective_and_params(
        step, None, params)
54+
55+
56+
#@step(r'I create a Linear Regression with objective and params$')
57+
def i_create_a_linear_regression_with_objective_and_params(step,
                                                           objective,
                                                           params):
    """Creates a linear regression from the dataset stored in `world`

    params: JSON string with the creation arguments
    objective: when not None, it is added to the creation arguments as
               "objective_field"

    The response code, location and object are kept in `world`, and the
    new resource id is appended to `world.linear_regressions`.
    """
    creation_args = json.loads(params)
    if objective is not None:
        creation_args.update(objective_field=objective)
    dataset_id = world.dataset.get('resource')
    response = world.api.create_linear_regression(dataset_id, creation_args)
    world.status = response['code']
    eq_(world.status, HTTP_CREATED)
    world.location = response['location']
    world.linear_regression = response['object']
    world.linear_regressions.append(response['resource'])
70+
71+
def i_create_a_linear_regression(step):
    """Alias of `i_create_a_linear_regression_from_dataset`"""
    i_create_a_linear_regression_from_dataset(step)
73+
74+
75+
#@step(r'I update the linear regression name to "(.*)"$')
76+
def i_update_linear_regression_name(step, name):
    """Updates the name of the linear regression stored in `world`

    The response code, location and object are kept in `world`.
    """
    response = world.api.update_linear_regression(
        world.linear_regression['resource'], {'name': name})
    world.status = response['code']
    eq_(world.status, HTTP_ACCEPTED)
    world.location = response['location']
    world.linear_regression = response['object']
84+
85+
86+
#@step(r'I wait until the linear regression status code is either (\d) or (-\d) less than (\d+)')
87+
def wait_until_linear_regression_status_code_is(step, code1, code2, secs):
    """Polls the linear regression until its status is `code1` or `code2`

    The resource is retrieved every 3 seconds. The elapsed time is
    asserted to stay below `secs` scaled by `world.delta`, and the final
    status code is asserted to be `code1`.
    """
    started_at = datetime.utcnow()
    max_seconds = int(secs) * world.delta
    resource_id = world.linear_regression['resource']
    i_get_the_linear_regression(step, resource_id)
    status = get_status(world.linear_regression)
    while not (status['code'] == int(code1) or
               status['code'] == int(code2)):
        time.sleep(3)
        assert_less(datetime.utcnow() - started_at,
                    timedelta(seconds=max_seconds))
        i_get_the_linear_regression(step, resource_id)
        status = get_status(world.linear_regression)
    eq_(status['code'], int(code1))
100+
101+
102+
#@step(r'I wait until the linear is ready less than (\d+)')
103+
def the_linear_regression_is_finished_in_less_than(step, secs):
    """Waits for the linear regression to reach FINISHED or FAULTY"""
    wait_until_linear_regression_status_code_is(
        step, FINISHED, FAULTY, secs)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# -*- coding: utf-8 -*-
2+
#!/usr/bin/env python
3+
#
4+
# Copyright 2018-2019 BigML
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
7+
# not use this file except in compliance with the License. You may obtain
8+
# a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
# License for the specific language governing permissions and limitations
16+
# under the License.
17+
18+
import time
19+
import json
20+
import os
21+
from datetime import datetime, timedelta
22+
from world import world
23+
from nose.tools import eq_, assert_less
24+
25+
from bigml.api import HTTP_CREATED
26+
from bigml.api import HTTP_ACCEPTED
27+
from bigml.api import FINISHED
28+
from bigml.api import FAULTY
29+
from bigml.api import get_status
30+
31+
from read_pca_steps import i_get_the_pca
32+
33+
34+
#@step(r'the pca name is "(.*)"')
35+
def i_check_pca_name(step, name):
    """Checks that the stored PCA has the expected name"""
    eq_(name, world.pca['name'])
38+
39+
#@step(r'I create a PCA from a dataset$')
40+
def i_create_a_pca_from_dataset(step):
    """Creates a PCA from the dataset stored in `world`

    The response code, location and object are kept in `world`, and the
    new resource id is appended to `world.pcas`.
    """
    dataset_id = world.dataset.get('resource')
    response = world.api.create_pca(dataset_id, {'name': 'new PCA'})
    world.status = response['code']
    eq_(world.status, HTTP_CREATED)
    world.location = response['location']
    world.pca = response['object']
    world.pcas.append(response['resource'])
48+
49+
50+
#@step(r'I create a PCA from a dataset$')
51+
def i_create_a_pca_with_params(step, params):
    """Creates a PCA from the dataset stored in `world` using `params`

    params: JSON string with the creation arguments

    The response code, location and object are kept in `world`, and the
    new resource id is appended to `world.pcas`.
    """
    creation_args = json.loads(params)
    dataset_id = world.dataset.get('resource')
    response = world.api.create_pca(dataset_id, creation_args)
    world.status = response['code']
    eq_(world.status, HTTP_CREATED)
    world.location = response['location']
    world.pca = response['object']
    world.pcas.append(response['resource'])
60+
61+
def i_create_a_pca(step):
    """Alias of `i_create_a_pca_from_dataset`"""
    i_create_a_pca_from_dataset(step)
63+
64+
65+
#@step(r'I update the PCA name to "(.*)"$')
66+
def i_update_pca_name(step, name):
    """Updates the name of the PCA stored in `world`

    The response code, location and object are kept in `world`.
    """
    response = world.api.update_pca(
        world.pca['resource'], {'name': name})
    world.status = response['code']
    eq_(world.status, HTTP_ACCEPTED)
    world.location = response['location']
    world.pca = response['object']
73+
74+
75+
#@step(r'I wait until the PCA status code is either (\d) or (-\d) less than (\d+)')
76+
def wait_until_pca_status_code_is(step, code1, code2, secs):
    """Polls the PCA until its status is `code1` or `code2`

    The resource is retrieved every 3 seconds. The elapsed time is
    asserted to stay below `secs` scaled by `world.delta`, and the final
    status code is asserted to be `code1`.
    """
    started_at = datetime.utcnow()
    max_seconds = int(secs) * world.delta
    resource_id = world.pca['resource']
    i_get_the_pca(step, resource_id)
    status = get_status(world.pca)
    while not (status['code'] == int(code1) or
               status['code'] == int(code2)):
        time.sleep(3)
        assert_less(datetime.utcnow() - started_at,
                    timedelta(seconds=max_seconds))
        i_get_the_pca(step, resource_id)
        status = get_status(world.pca)
    eq_(status['code'], int(code1))
89+
90+
91+
#@step(r'I wait until the PCA is ready less than (\d+)')
92+
def the_pca_is_finished_in_less_than(step, secs):
    """Waits for the PCA to reach FINISHED or FAULTY"""
    wait_until_pca_status_code_is(
        step, FINISHED, FAULTY, secs)

0 commit comments

Comments
 (0)