Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 3e72cc5

Browse files
densmirnAlexanderKalistratov
authored andcommitted
Overload df.rolling.max() (#476)
* Overload df.rolling.max() * Add perf.test for df.rolling.max() * Make comment more clear in test for df.rolling.max
1 parent 42fc172 commit 3e72cc5

File tree

4 files changed

+88
-14
lines changed

4 files changed

+88
-14
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2020, Intel Corporation All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
#
10+
# Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
16+
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
17+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
18+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
21+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
22+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
23+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
24+
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
# *****************************************************************************
26+
27+
import pandas as pd
28+
from numba import njit
29+
30+
31+
@njit
32+
def df_rolling_max():
33+
df = pd.DataFrame({'A': [4, 3, 5, 2, 6], 'B': [-4, -3, -5, -2, -6]})
34+
out_df = df.rolling(3).max()
35+
36+
# Expect DataFrame of
37+
# {'A': [NaN, NaN, 5.0, 5.0, 6.0], 'B': [NaN, NaN, -3.0, -2.0, -2.0]}
38+
return out_df
39+
40+
41+
print(df_rolling_max())

sdc/datatypes/hpat_pandas_dataframe_rolling_functions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,15 @@ def sdc_pandas_dataframe_rolling_count(self):
149149
return gen_df_rolling_method_impl('count', self)
150150

151151

152+
@sdc_overload_method(DataFrameRollingType, 'max')
153+
def sdc_pandas_dataframe_rolling_max(self):
154+
155+
ty_checker = TypeChecker('Method rolling.max().')
156+
ty_checker.check(self, DataFrameRollingType)
157+
158+
return gen_df_rolling_method_impl('max', self)
159+
160+
152161
@sdc_overload_method(DataFrameRollingType, 'min')
153162
def sdc_pandas_dataframe_rolling_min(self):
154163

@@ -185,6 +194,13 @@ def sdc_pandas_dataframe_rolling_min(self):
185194
'extra_params': ''
186195
})
187196

197+
sdc_pandas_dataframe_rolling_max.__doc__ = sdc_pandas_dataframe_rolling_docstring_tmpl.format(**{
198+
'method_name': 'max',
199+
'example_caption': 'Calculate the rolling maximum.',
200+
'limitations_block': '',
201+
'extra_params': ''
202+
})
203+
188204
sdc_pandas_dataframe_rolling_min.__doc__ = sdc_pandas_dataframe_rolling_docstring_tmpl.format(**{
189205
'method_name': 'min',
190206
'example_caption': 'Calculate the rolling minimum.',

sdc/tests/test_rolling.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -583,14 +583,30 @@ def test_impl(obj, window, min_periods):
583583
ref_result = test_impl(obj, window, min_periods)
584584
assert_equal(jit_result, ref_result)
585585

586+
def _test_rolling_max(self, obj):
587+
def test_impl(obj, window, min_periods):
588+
return obj.rolling(window, min_periods).max()
589+
590+
hpat_func = self.jit(test_impl)
591+
assert_equal = self._get_assert_equal(obj)
592+
593+
# python implementation crashes if window = 0, jit works correctly
594+
for window in range(1, len(obj) + 2):
595+
for min_periods in range(window + 1):
596+
with self.subTest(obj=obj, window=window,
597+
min_periods=min_periods):
598+
jit_result = hpat_func(obj, window, min_periods)
599+
ref_result = test_impl(obj, window, min_periods)
600+
assert_equal(jit_result, ref_result)
601+
586602
def _test_rolling_min(self, obj):
587603
def test_impl(obj, window, min_periods):
588604
return obj.rolling(window, min_periods).min()
589605

590606
hpat_func = self.jit(test_impl)
591607
assert_equal = self._get_assert_equal(obj)
592608

593-
# TODO: fix the issue when window = 0
609+
# python implementation crashes if window = 0, jit works correctly
594610
for window in range(1, len(obj) + 2):
595611
for min_periods in range(window + 1):
596612
with self.subTest(obj=obj, window=window,
@@ -661,6 +677,15 @@ def test_df_rolling_count(self):
661677

662678
self._test_rolling_count(df)
663679

680+
@skip_sdc_jit('DataFrame.rolling.max() unsupported')
681+
def test_df_rolling_max(self):
682+
all_data = test_global_input_data_float64
683+
length = min(len(d) for d in all_data)
684+
data = {n: d[:length] for n, d in zip(string.ascii_uppercase, all_data)}
685+
df = pd.DataFrame(data)
686+
687+
self._test_rolling_max(df)
688+
664689
@skip_sdc_jit('DataFrame.rolling.min() unsupported')
665690
def test_df_rolling_min(self):
666691
all_data = test_global_input_data_float64
@@ -893,23 +918,11 @@ def test_impl(series, window, min_periods):
893918

894919
@skip_sdc_jit('Series.rolling.max() unsupported Series index')
895920
def test_series_rolling_max(self):
896-
def test_impl(series, window, min_periods):
897-
return series.rolling(window, min_periods).max()
898-
899-
hpat_func = self.jit(test_impl)
900-
901921
all_data = test_global_input_data_float64
902922
indices = [list(range(len(data)))[::-1] for data in all_data]
903923
for data, index in zip(all_data, indices):
904924
series = pd.Series(data, index, name='A')
905-
# TODO: fix the issue when window = 0
906-
for window in range(1, len(series) + 2):
907-
for min_periods in range(window + 1):
908-
with self.subTest(series=series, window=window,
909-
min_periods=min_periods):
910-
jit_result = hpat_func(series, window, min_periods)
911-
ref_result = test_impl(series, window, min_periods)
912-
pd.testing.assert_series_equal(jit_result, ref_result)
925+
self._test_rolling_max(series)
913926

914927
@skip_sdc_jit('Series.rolling.mean() unsupported Series index')
915928
def test_series_rolling_mean(self):

sdc/tests/tests_perf/test_perf_df_rolling.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def setUpClass(cls):
8686
cls.total_data_length = {
8787
'apply': [2 * 10 ** 5],
8888
'count': [8 * 10 ** 5],
89+
'max': [2 * 10 ** 5],
8990
'min': [2 * 10 ** 5],
9091
}
9192

@@ -143,5 +144,8 @@ def test_df_rolling_apply_mean(self):
143144
def test_df_rolling_count(self):
144145
self._test_df_rolling_method('count')
145146

147+
def test_df_rolling_max(self):
148+
self._test_df_rolling_method('max')
149+
146150
def test_df_rolling_min(self):
147151
self._test_df_rolling_method('min')

0 commit comments

Comments
 (0)