Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 333e843

Browse files
authored
Overload df.rolling.var() (#483)
* Overload df.rolling.var() * Fix a code style issue * Add perf.test for df.rolling.var * Minor changes in a test for df.rolling.var
1 parent 3b843fe commit 333e843

File tree

4 files changed

+123
-19
lines changed

4 files changed

+123
-19
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2020, Intel Corporation All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
#
10+
# Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
16+
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
17+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
18+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
21+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
22+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
23+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
24+
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
import pandas as pd
28+
from numba import njit
29+
30+
31+
@njit
32+
def df_rolling_var():
33+
df = pd.DataFrame({'A': [4, 3, 5, 2, 6], 'B': [-4, -3, -5, -2, -6]})
34+
out_df = df.rolling(3).var()
35+
36+
# Expect DataFrame of
37+
# {'A': [NaN, NaN, 1.000000, 2.333333, 4.333333],
38+
# 'B': [NaN, NaN, 1.000000, 2.333333, 4.333333]}
39+
return out_df
40+
41+
42+
print(df_rolling_var())

sdc/datatypes/hpat_pandas_dataframe_rolling_functions.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,18 @@ def sdc_pandas_dataframe_rolling_sum(self):
400400
return gen_df_rolling_method_impl('sum', self)
401401

402402

403+
@sdc_overload_method(DataFrameRollingType, 'var')
404+
def sdc_pandas_dataframe_rolling_var(self, ddof=1):
405+
406+
ty_checker = TypeChecker('Method rolling.var().')
407+
ty_checker.check(self, DataFrameRollingType)
408+
409+
if not isinstance(ddof, (int, Integer, Omitted)):
410+
ty_checker.raise_exc(ddof, 'int', 'ddof')
411+
412+
return gen_df_rolling_method_impl('var', self, kws={'ddof': '1'})
413+
414+
403415
sdc_pandas_dataframe_rolling_apply.__doc__ = sdc_pandas_dataframe_rolling_docstring_tmpl.format(**{
404416
'method_name': 'apply',
405417
'example_caption': 'Calculate the rolling apply.',
@@ -538,3 +550,19 @@ def sdc_pandas_dataframe_rolling_sum(self):
538550
""",
539551
'extra_params': ''
540552
})
553+
554+
sdc_pandas_dataframe_rolling_var.__doc__ = sdc_pandas_dataframe_rolling_docstring_tmpl.format(**{
555+
'method_name': 'var',
556+
'example_caption': 'Calculate unbiased rolling variance.',
557+
'limitations_block':
558+
"""
559+
Limitations
560+
-----------
561+
DataFrame elements cannot be max/min float/integer. Otherwise SDC and Pandas results are different.
562+
""",
563+
'extra_params':
564+
"""
565+
ddof: :obj:`int`
566+
Delta Degrees of Freedom.
567+
"""
568+
})

sdc/tests/test_rolling.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def rolling_std_usecase(obj, window, min_periods, ddof):
5656
return obj.rolling(window, min_periods).std(ddof)
5757

5858

59-
def series_rolling_var_usecase(series, window, min_periods, ddof):
60-
return series.rolling(window, min_periods).var(ddof)
59+
def rolling_var_usecase(obj, window, min_periods, ddof):
60+
return obj.rolling(window, min_periods).var(ddof)
6161

6262

6363
class TestRolling(TestCase):
@@ -807,6 +807,29 @@ def test_impl(obj, window, min_periods):
807807
ref_result = test_impl(obj, window, min_periods)
808808
assert_equal(jit_result, ref_result)
809809

810+
def _test_rolling_var(self, obj):
811+
test_impl = rolling_var_usecase
812+
hpat_func = self.jit(test_impl)
813+
assert_equal = self._get_assert_equal(obj)
814+
815+
for window in range(0, len(obj) + 3, 2):
816+
for min_periods, ddof in product(range(0, window, 2), [0, 1]):
817+
with self.subTest(obj=obj, window=window,
818+
min_periods=min_periods, ddof=ddof):
819+
jit_result = hpat_func(obj, window, min_periods, ddof)
820+
ref_result = test_impl(obj, window, min_periods, ddof)
821+
assert_equal(jit_result, ref_result)
822+
823+
def _test_rolling_var_exception_unsupported_ddof(self, obj):
824+
test_impl = rolling_var_usecase
825+
hpat_func = self.jit(test_impl)
826+
827+
window, min_periods, invalid_ddof = 3, 2, '1'
828+
with self.assertRaises(TypingError) as raises:
829+
hpat_func(obj, window, min_periods, invalid_ddof)
830+
msg = 'Method rolling.var(). The object ddof\n given: unicode_type\n expected: int'
831+
self.assertIn(msg, str(raises.exception))
832+
810833
@skip_sdc_jit('DataFrame.rolling.min() unsupported exceptions')
811834
def test_df_rolling_unsupported_values(self):
812835
all_data = test_global_input_data_float64
@@ -1075,6 +1098,28 @@ def test_df_rolling_sum(self):
10751098

10761099
self._test_rolling_sum(df)
10771100

1101+
@skip_sdc_jit('DataFrame.rolling.var() unsupported')
1102+
def test_df_rolling_var(self):
1103+
all_data = [
1104+
list(range(10)), [1., -1., 0., 0.1, -0.1],
1105+
[1., np.inf, np.inf, -1., 0., np.inf, np.NINF, np.NINF],
1106+
[np.nan, np.inf, np.inf, np.nan, np.nan, np.nan, np.NINF, np.NZERO]
1107+
]
1108+
length = min(len(d) for d in all_data)
1109+
data = {n: d[:length] for n, d in zip(string.ascii_uppercase, all_data)}
1110+
df = pd.DataFrame(data)
1111+
1112+
self._test_rolling_var(df)
1113+
1114+
@skip_sdc_jit('DataFrame.rolling.var() unsupported exceptions')
1115+
def test_df_rolling_var_exception_unsupported_ddof(self):
1116+
all_data = [[1., -1., 0., 0.1, -0.1], [-1., 1., 0., -0.1, 0.1]]
1117+
length = min(len(d) for d in all_data)
1118+
data = {n: d[:length] for n, d in zip(string.ascii_uppercase, all_data)}
1119+
df = pd.DataFrame(data)
1120+
1121+
self._test_rolling_var_exception_unsupported_ddof(df)
1122+
10781123
@skip_sdc_jit('Series.rolling.min() unsupported exceptions')
10791124
def test_series_rolling_unsupported_values(self):
10801125
series = pd.Series(test_global_input_data_float64[0])
@@ -1335,9 +1380,6 @@ def test_series_rolling_sum(self):
13351380

13361381
@skip_sdc_jit('Series.rolling.var() unsupported Series index')
13371382
def test_series_rolling_var(self):
1338-
test_impl = series_rolling_var_usecase
1339-
hpat_func = self.jit(test_impl)
1340-
13411383
all_data = [
13421384
list(range(10)), [1., -1., 0., 0.1, -0.1],
13431385
[1., np.inf, np.inf, -1., 0., np.inf, np.NINF, np.NINF],
@@ -1346,24 +1388,12 @@ def test_series_rolling_var(self):
13461388
indices = [list(range(len(data)))[::-1] for data in all_data]
13471389
for data, index in zip(all_data, indices):
13481390
series = pd.Series(data, index, name='A')
1349-
for window in range(0, len(series) + 3, 2):
1350-
for min_periods, ddof in product(range(0, window, 2), [0, 1]):
1351-
with self.subTest(series=series, window=window,
1352-
min_periods=min_periods, ddof=ddof):
1353-
jit_result = hpat_func(series, window, min_periods, ddof)
1354-
ref_result = test_impl(series, window, min_periods, ddof)
1355-
pd.testing.assert_series_equal(jit_result, ref_result)
1391+
self._test_rolling_var(series)
13561392

13571393
@skip_sdc_jit('Series.rolling.var() unsupported exceptions')
13581394
def test_series_rolling_var_exception_unsupported_ddof(self):
1359-
test_impl = series_rolling_var_usecase
1360-
hpat_func = self.jit(test_impl)
1361-
13621395
series = pd.Series([1., -1., 0., 0.1, -0.1])
1363-
with self.assertRaises(TypingError) as raises:
1364-
hpat_func(series, 3, 2, '1')
1365-
msg = 'Method rolling.var(). The object ddof\n given: unicode_type\n expected: int'
1366-
self.assertIn(msg, str(raises.exception))
1396+
self._test_rolling_var_exception_unsupported_ddof(series)
13671397

13681398

13691399
if __name__ == "__main__":

sdc/tests/tests_perf/test_perf_df_rolling.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def setUpClass(cls):
9898
'skew': [2 * 10 ** 5],
9999
'std': [2 * 10 ** 5],
100100
'sum': [2 * 10 ** 5],
101+
'var': [2 * 10 ** 5],
101102
}
102103

103104
def _test_jitted(self, pyfunc, record, *args, **kwargs):
@@ -209,3 +210,6 @@ def test_df_rolling_std(self):
209210

210211
def test_df_rolling_sum(self):
211212
self._test_df_rolling_method('sum')
213+
214+
def test_df_rolling_var(self):
215+
self._test_df_rolling_method('var')

0 commit comments

Comments
 (0)