Skip to content

Commit a19f13b

Browse files
authored
GH-32007: [Python] Support arithmetic on arrays and scalars (#48085)
### Rationale for this change Please see #32007, currently, neither arrays nor scalars support Python-native arithmetic operations, such as `array + array`, it has to be done via `pyarrow.compute` API. This PR strives to fix this with custom dunder methods. ### What changes are included in this PR? Implemented dunder methods ### Are these changes tested? Yes ### Are there any user-facing changes? Possibility to use Python operators directly instead of calling the `pyarrow.compute` API. * GitHub Issue: #32007 Authored-by: Bogdan Romenskii <rmnsk@seznam.cz> Signed-off-by: AlenkaF <frim.alenka@gmail.com>
1 parent 560ef02 commit a19f13b

5 files changed

Lines changed: 305 additions & 0 deletions

File tree

docs/source/python/compute.rst

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,3 +512,46 @@ For example, the "numpy_gcd" function that we've been using as an example above
512512
function to use in a projection. A "cumulative sum" function would not be a valid function
513513
since the result of each input row depends on the rows that came before. A "drop nulls"
514514
function would also be invalid because it doesn't emit a value for some rows.
515+
516+
517+
Standard Python Operators
518+
=========================
519+
520+
PyArrow supports standard Python operators for element-wise operations for arrays and scalars.
521+
Currently, the support is limited to some of the standard compute functions, i.e.
522+
arithmetic (``+``, ``-``, ``/``, ``%``, ``**``),
523+
bitwise (``&``, ``|``, ``^``, ``>>``, ``<<``) and others.
524+
525+
The aforementioned operators use checked version of underlying kernels wherever possible
526+
and have the same respective constraints, e.g. you cannot add two arrays of strings.
527+
528+
You can use the operators as following:
529+
530+
.. code-block:: python
531+
532+
>>> import pyarrow as pa
533+
>>> arr = pa.array([-1, 2, -3])
534+
>>> val = pa.scalar(42.7)
535+
>>> arr + val
536+
<pyarrow.lib.DoubleArray object at ...>
537+
[
538+
41.7,
539+
44.7,
540+
39.7
541+
]
542+
543+
>>> val ** arr
544+
<pyarrow.lib.DoubleArray object at ...>
545+
[
546+
0.023419203747072598,
547+
1823.2900000000002,
548+
0.000012844475506953143
549+
]
550+
551+
>>> arr << 2
552+
<pyarrow.lib.Int64Array object at ...>
553+
[
554+
-4,
555+
8,
556+
-12
557+
]

python/pyarrow/array.pxi

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
192192
pa.int16() even if pa.int8() was passed to the function. Note that an
193193
explicit index type will not be demoted even if it is wider than required.
194194
195+
This class supports Python's standard operators
196+
for element-wise operations, i.e. arithmetic (`+`, `-`, `/`, `%`, `**`),
197+
bitwise (`&`, `|`, `^`, `>>`, `<<`) and others.
198+
They can be used directly instead of calling underlying
199+
`pyarrow.compute` functions explicitly.
200+
195201
Examples
196202
--------
197203
>>> import pandas as pd
@@ -229,6 +235,25 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
229235
>>> arr = pa.array(range(1024), type=pa.dictionary(pa.int8(), pa.int64()))
230236
>>> arr.type.index_type
231237
DataType(int16)
238+
239+
>>> arr1 = pa.array([1, 2, 3], type=pa.int8())
240+
>>> arr2 = pa.array([4, 5, 6], type=pa.int8())
241+
>>> arr1 + arr2
242+
<pyarrow.lib.Int8Array object at ...>
243+
[
244+
5,
245+
7,
246+
9
247+
]
248+
249+
>>> val = pa.scalar(42)
250+
>>> val - arr1
251+
<pyarrow.lib.Int64Array object at ...>
252+
[
253+
41,
254+
40,
255+
39
256+
]
232257
"""
233258
cdef:
234259
CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
@@ -2259,6 +2284,54 @@ cdef class Array(_PandasConvertible):
22592284
stat.init(sp_stat)
22602285
return stat
22612286

2287+
def __abs__(self):
2288+
self._assert_cpu()
2289+
return _pc().call_function('abs_checked', [self])
2290+
2291+
def __add__(self, object other):
2292+
self._assert_cpu()
2293+
return _pc().call_function('add_checked', [self, other])
2294+
2295+
def __truediv__(self, object other):
2296+
self._assert_cpu()
2297+
return _pc().call_function('divide_checked', [self, other])
2298+
2299+
def __mul__(self, object other):
2300+
self._assert_cpu()
2301+
return _pc().call_function('multiply_checked', [self, other])
2302+
2303+
def __neg__(self):
2304+
self._assert_cpu()
2305+
return _pc().call_function('negate_checked', [self])
2306+
2307+
def __pow__(self, object other):
2308+
self._assert_cpu()
2309+
return _pc().call_function('power_checked', [self, other])
2310+
2311+
def __sub__(self, object other):
2312+
self._assert_cpu()
2313+
return _pc().call_function('subtract_checked', [self, other])
2314+
2315+
def __and__(self, object other):
2316+
self._assert_cpu()
2317+
return _pc().call_function('bit_wise_and', [self, other])
2318+
2319+
def __or__(self, object other):
2320+
self._assert_cpu()
2321+
return _pc().call_function('bit_wise_or', [self, other])
2322+
2323+
def __xor__(self, object other):
2324+
self._assert_cpu()
2325+
return _pc().call_function('bit_wise_xor', [self, other])
2326+
2327+
def __lshift__(self, object other):
2328+
self._assert_cpu()
2329+
return _pc().call_function('shift_left_checked', [self, other])
2330+
2331+
def __rshift__(self, object other):
2332+
self._assert_cpu()
2333+
return _pc().call_function('shift_right_checked', [self, other])
2334+
22622335

22632336
cdef _array_like_to_pandas(obj, options, types_mapper):
22642337
cdef:

python/pyarrow/scalar.pxi

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,33 @@ from collections.abc import Sequence, Mapping
2424
cdef class Scalar(_Weakrefable):
2525
"""
2626
The base class for scalars.
27+
28+
Notes
29+
-----
30+
This class supports Python's standard operators
31+
for element-wise operations, i.e. arithmetic (`+`, `-`, `/`, `%`, `**`),
32+
bitwise (`&`, `|`, `^`, `>>`, `<<`) and others.
33+
They can be used directly instead of calling underlying
34+
`pyarrow.compute` functions explicitly.
35+
36+
Examples
37+
--------
38+
>>> import pyarrow as pa
39+
>>> pa.scalar(42) + pa.scalar(17)
40+
<pyarrow.Int64Scalar: 59>
41+
42+
>>> pa.scalar(6) ** 3
43+
<pyarrow.Int64Scalar: 216>
44+
45+
>>> arr = pa.array([1, 2, 3], type=pa.int8())
46+
>>> val = pa.scalar(42)
47+
>>> val - arr
48+
<pyarrow.lib.Int64Array object at ...>
49+
[
50+
41,
51+
40,
52+
39
53+
]
2754
"""
2855

2956
def __init__(self):
@@ -168,6 +195,42 @@ cdef class Scalar(_Weakrefable):
168195
"""
169196
raise NotImplementedError()
170197

198+
def __abs__(self):
199+
return _pc().call_function('abs_checked', [self])
200+
201+
def __add__(self, object other):
202+
return _pc().call_function('add_checked', [self, other])
203+
204+
def __truediv__(self, object other):
205+
return _pc().call_function('divide_checked', [self, other])
206+
207+
def __mul__(self, object other):
208+
return _pc().call_function('multiply_checked', [self, other])
209+
210+
def __neg__(self):
211+
return _pc().call_function('negate_checked', [self])
212+
213+
def __pow__(self, object other):
214+
return _pc().call_function('power_checked', [self, other])
215+
216+
def __sub__(self, object other):
217+
return _pc().call_function('subtract_checked', [self, other])
218+
219+
def __and__(self, object other):
220+
return _pc().call_function('bit_wise_and', [self, other])
221+
222+
def __or__(self, object other):
223+
return _pc().call_function('bit_wise_or', [self, other])
224+
225+
def __xor__(self, object other):
226+
return _pc().call_function('bit_wise_xor', [self, other])
227+
228+
def __lshift__(self, object other):
229+
return _pc().call_function('shift_left_checked', [self, other])
230+
231+
def __rshift__(self, object other):
232+
return _pc().call_function('shift_right_checked', [self, other])
233+
171234

172235
_NULL = NA = None
173236

python/pyarrow/tests/test_array.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import pyarrow as pa
3636
import pyarrow.tests.strategies as past
3737
from pyarrow.vendored.version import Version
38+
import pyarrow.compute as pc
3839

3940

4041
@pytest.mark.processes
@@ -4398,3 +4399,72 @@ def test_non_cpu_array():
43984399
arr.tolist()
43994400
with pytest.raises(NotImplementedError):
44004401
arr.validate(full=True)
4402+
4403+
4404+
def test_arithmetic_dunders():
4405+
# GH-32007
4406+
arr1 = pa.array([-1.1, 2.2, -3.3])
4407+
arr2 = pa.array([2.2, 4.4, 5.5])
4408+
4409+
assert (arr1 + arr2).equals(pc.add_checked(arr1, arr2))
4410+
assert (arr2 / arr1).equals(pc.divide_checked(arr2, arr1))
4411+
assert (arr1 * arr2).equals(pc.multiply_checked(arr1, arr2))
4412+
assert (-arr1).equals(pc.negate_checked(arr1))
4413+
assert (arr1 ** 2).equals(pc.power_checked(arr1, 2))
4414+
assert (arr1 - arr2).equals(pc.subtract_checked(arr1, arr2))
4415+
4416+
4417+
def test_bitwise_dunders():
4418+
# GH-32007
4419+
arr1 = pa.array([-1, 2, -3])
4420+
arr2 = pa.array([2, 4, 5])
4421+
4422+
assert (arr1 & arr2).equals(pc.bit_wise_and(arr1, arr2))
4423+
assert (arr1 | arr2).equals(pc.bit_wise_or(arr1, arr2))
4424+
assert (arr1 ^ arr2).equals(pc.bit_wise_xor(arr1, arr2))
4425+
assert (arr1 << arr2).equals(pc.shift_left_checked(arr1, arr2))
4426+
assert (arr1 >> arr2).equals(pc.shift_right_checked(arr1, arr2))
4427+
4428+
4429+
def test_dunders_unmatching_types():
4430+
# GH-32007
4431+
error_match = r"Function '\w+' has no kernel matching input types"
4432+
string_arr = pa.array(["a", "b", "c"])
4433+
nested_arr = pa.array([{"x": 1, "y": True}, {"z": 3.4, "x": 4}])
4434+
double_arr = pa.array([1.0, 2.0, 3.0])
4435+
4436+
with pytest.raises(pa.ArrowNotImplementedError, match=error_match):
4437+
string_arr + nested_arr
4438+
with pytest.raises(pa.ArrowNotImplementedError, match=error_match):
4439+
string_arr - double_arr
4440+
with pytest.raises(pa.ArrowNotImplementedError, match=error_match):
4441+
double_arr * nested_arr
4442+
4443+
4444+
def test_dunders_mixed_types():
4445+
# GH-32007
4446+
arr = pa.array([11.0, 17.0, 23.0])
4447+
val = pa.scalar(3)
4448+
4449+
assert (arr + val).equals(pc.add_checked(arr, val))
4450+
assert (arr - val).equals(pc.subtract_checked(arr, val))
4451+
assert (arr / val).equals(pc.divide_checked(arr, val))
4452+
assert (arr * val).equals(pc.multiply_checked(arr, val))
4453+
assert (arr ** val).equals(pc.power_checked(arr, val))
4454+
4455+
4456+
def test_dunders_checked_overflow():
4457+
# GH-32007
4458+
arr = pa.array([127, -128], type=pa.int8())
4459+
error_match = "overflow"
4460+
4461+
with pytest.raises(pa.ArrowInvalid, match=error_match):
4462+
arr + arr
4463+
with pytest.raises(pa.ArrowInvalid, match=error_match):
4464+
arr * arr
4465+
with pytest.raises(pa.ArrowInvalid, match=error_match):
4466+
arr - (-arr)
4467+
with pytest.raises(pa.ArrowInvalid, match=error_match):
4468+
arr ** pa.scalar(2, type=pa.int8())
4469+
with pytest.raises(pa.ArrowInvalid, match=error_match):
4470+
arr / (-arr)

python/pyarrow/tests/test_scalars.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,3 +995,59 @@ def test_map_scalar_with_empty_values():
995995
s = pa.scalar(v, type=map_type)
996996

997997
assert s.as_py(maps_as_pydicts="strict") == v
998+
999+
1000+
def test_arithmetic_dunders():
1001+
# GH-32007
1002+
scl1 = pa.scalar(42)
1003+
scl2 = pa.scalar(-17)
1004+
1005+
assert (scl1 + scl2).equals(pc.add_checked(scl1, scl2))
1006+
assert (scl2 / scl1).equals(pc.divide_checked(scl2, scl1))
1007+
assert (scl1 * scl2).equals(pc.multiply_checked(scl1, scl2))
1008+
assert (-scl1).equals(pc.negate_checked(scl1))
1009+
assert (scl1 ** 2).equals(pc.power_checked(scl1, 2))
1010+
assert (scl1 - scl2).equals(pc.subtract_checked(scl1, scl2))
1011+
1012+
1013+
def test_bitwise_dunders():
1014+
# GH-32007
1015+
scl1 = pa.scalar(42)
1016+
scl2 = pa.scalar(-17)
1017+
1018+
assert (scl1 & scl2).equals(pc.bit_wise_and(scl1, scl2))
1019+
assert (scl1 | scl2).equals(pc.bit_wise_or(scl1, scl2))
1020+
assert (scl1 ^ scl2).equals(pc.bit_wise_xor(scl1, scl2))
1021+
assert (scl2 << scl1).equals(pc.shift_left_checked(scl2, scl1))
1022+
assert (scl2 >> scl1).equals(pc.shift_right_checked(scl2, scl1))
1023+
1024+
1025+
def test_dunders_unmatching_types():
1026+
# GH-32007
1027+
error_match = r"Function '\w+' has no kernel matching input types"
1028+
string_scl = pa.scalar("abc")
1029+
double_scl = pa.scalar(1.23)
1030+
1031+
with pytest.raises(pa.ArrowNotImplementedError, match=error_match):
1032+
string_scl + double_scl
1033+
with pytest.raises(pa.ArrowNotImplementedError, match=error_match):
1034+
string_scl - double_scl
1035+
with pytest.raises(pa.ArrowNotImplementedError, match=error_match):
1036+
string_scl / double_scl
1037+
with pytest.raises(pa.ArrowNotImplementedError, match=error_match):
1038+
string_scl * double_scl
1039+
1040+
1041+
def test_dunders_checked_overflow():
1042+
# GH-32007
1043+
error_match = "overflow"
1044+
scl = pa.scalar(127, type=pa.int8())
1045+
1046+
with pytest.raises(pa.ArrowInvalid, match=error_match):
1047+
scl + scl
1048+
with pytest.raises(pa.ArrowInvalid, match=error_match):
1049+
scl - (-scl)
1050+
with pytest.raises(pa.ArrowInvalid, match=error_match):
1051+
scl ** scl
1052+
with pytest.raises(pa.ArrowInvalid, match=error_match):
1053+
scl * scl

0 commit comments

Comments
 (0)