Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.
Prev Previous commit
Next Next commit
Fix type checking
  • Loading branch information
etotmeni committed Apr 14, 2020
commit ecce368f8cc4d5cf5071fbecc180cdfa48cd2fb7
19 changes: 13 additions & 6 deletions sdc/datatypes/hpat_pandas_dataframe_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,12 +1819,19 @@ def sdc_pandas_dataframe_accessor_getitem(self, idx):
accessor = self.accessor.literal_value

if accessor == 'at':
if isinstance(idx, types.Tuple) and isinstance(idx[1], types.Literal):
row = idx[0]
col = idx[1].literal_value
return gen_df_getitem_tuple_at_impl(self.dataframe, row, col)

raise TypingError('Operator getitem(). The index must be a row and literal column. Given: {}'.format(idx))
num_idx = isinstance(idx[0], types.Number) and isinstance(self.dataframe.index, types.Array)
str_idx = (isinstance(idx[0], (types.UnicodeType, types.StringLiteral))
and isinstance(self.dataframe.index, StringArrayType))
if isinstance(idx, types.Tuple) and isinstance(idx[1], types.StringLiteral):
if num_idx or str_idx:
row = idx[0]
col = idx[1].literal_value
return gen_df_getitem_tuple_at_impl(self.dataframe, row, col)

raise TypingError('Operator at(). The row parameter type ({}) is different from the index type\
({})'.format(type(idx[0]), type(self.dataframe.index.dtype)))

raise TypingError('Operator at(). The index must be a row and literal column. Given: {}'.format(idx))

if accessor == 'iat':
if isinstance(idx, types.Tuple) and isinstance(idx[1], types.Literal):
Expand Down
13 changes: 13 additions & 0 deletions sdc/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,19 @@ def test_impl(df, n):
for n in n_cases:
self.assertEqual(sdc_func(df, n), test_impl(df, n))

def test_df_at_type(self):
def test_impl(df, n, k):
return df.at[n, "B"]

sdc_func = sdc.jit(test_impl)
idx = ['3', '4', '1', '2', '0']
n_cases = ['2', '3']
df = pd.DataFrame({"A": [3.2, 4.4, 7.0, 3.3, 1.0],
"B": [3, 4, 1, 0, 222],
"C": ['a', 'dd', 'c', '12', 'ddf']}, index=idx)
for n in n_cases:
self.assertEqual(sdc_func(df, n, "B"), test_impl(df, n, "B"))

def test_df_at_value_error(self):
def test_impl(df):
return df.at[1, 'D']
Expand Down