Skip to content

Commit 87232a8

Browse files
gselzerctrueden
authored andcommitted
The return of ND arrays!
They're slower than primitive arrays, but better to have slow arrays than no arrays?
1 parent ace3ed5 commit 87232a8

2 files changed

Lines changed: 45 additions & 19 deletions

File tree

src/scyjava/_java.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,11 +523,30 @@ def jarray(kind, lengths: Sequence):
523523
if mode == Mode.JEP:
524524
import jep # noqa: F401
525525

526-
# TODO: Support n-d arrays
527-
if len(lengths) > 1:
528-
raise RuntimeError("jep cannot support 2+ dimensional arrays!")
529-
# instantiate the n-dimensional array
530-
arr = jep.jarray(lengths[0], arraytype)
526+
if len(lengths) == 1:
527+
# Fast case: 1-d array (we can use primitives)
528+
arr = jep.jarray(lengths[0], arraytype)
529+
else:
530+
# Slow case: n-d array (we cannot use primitives)
531+
# See https://github.com/ninia/jep/issues/439
532+
kinds = {
533+
"b": jimport("java.lang.Byte"),
534+
"c": jimport("java.lang.Character"),
535+
"d": jimport("java.lang.Double"),
536+
"f": jimport("java.lang.Float"),
537+
"i": jimport("java.lang.Integer"),
538+
"j": jimport("java.lang.Long"),
539+
"s": jimport("java.lang.Short"),
540+
"z": jimport("java.lang.Boolean"),
541+
}
542+
if arraytype in kinds:
543+
arraytype = kinds[arraytype]
544+
kind = arraytype
545+
# build up the array type
546+
for _ in range(len(lengths) - 1):
547+
arraytype = jep.jarray(0, arraytype)
548+
# instantiate the n-dimensional array
549+
arr = jep.jarray(lengths[0], arraytype)
531550

532551
elif mode == Mode.JPYPE:
533552
start_jvm()

tests/test_arrays.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import pytest
32

43
from scyjava import is_jarray, jarray, to_python
54
from scyjava.config import Mode, mode
@@ -47,9 +46,6 @@ def assert_array_conversion_works(jarr, expected):
4746
assert_array_conversion_works(jints, deltas)
4847

4948
def test_jarray2d_to_python(self):
50-
if mode is Mode.JEP:
51-
pytest.skip("Jep doesn't support 2-d arrays")
52-
5349
nums = [
5450
[1.2, 3.4, 5.6],
5551
[7.8, 9.1, 2.3],
@@ -69,7 +65,9 @@ def test_jarray2d_to_python(self):
6965
pdoubles = to_python(jdoubles)
7066

7167
if mode == Mode.JEP:
72-
raise RuntimeError("Not supported")
68+
assert isinstance(pdoubles, list)
69+
assert all(isinstance(v, list) for v in pdoubles)
70+
assert len(nums) == len(pdoubles)
7371

7472
elif mode == Mode.JPYPE:
7573
assert isinstance(pdoubles, np.ndarray)
@@ -81,9 +79,6 @@ def test_jarray2d_to_python(self):
8179
assert nums[i][j] == pdoubles[i][j]
8280

8381
def test_jarray2d_to_python_updates(self):
84-
if mode is Mode.JEP:
85-
pytest.skip("Jep doesn't support 2-d arrays")
86-
8782
nums_init = [
8883
[1.2, 3.4, 5.6],
8984
[7.8, 9.1, 2.3],
@@ -105,9 +100,15 @@ def test_jarray2d_to_python_updates(self):
105100

106101
# assert narr initial state
107102
pdoubles = to_python(jdoubles)
108-
assert isinstance(pdoubles, np.ndarray)
109-
assert np.float64 == pdoubles.dtype
110-
assert (5, 3) == pdoubles.shape
103+
if mode == Mode.JEP:
104+
assert isinstance(pdoubles, list)
105+
assert isinstance(pdoubles[0][0], float)
106+
assert len(pdoubles) == 5
107+
assert len(pdoubles[0]) == 3
108+
elif mode == Mode.JPYPE:
109+
assert isinstance(pdoubles, np.ndarray)
110+
assert np.float64 == pdoubles.dtype
111+
assert (5, 3) == pdoubles.shape
111112
for i in range(len(nums_init)):
112113
for j in range(len(nums_init[i])):
113114
assert nums_init[i][j] == pdoubles[i][j]
@@ -119,9 +120,15 @@ def test_jarray2d_to_python_updates(self):
119120

120121
# assert narr delta state
121122
pdoubles = to_python(jdoubles)
122-
assert isinstance(pdoubles, np.ndarray)
123-
assert np.float64 == pdoubles.dtype
124-
assert (5, 3) == pdoubles.shape
123+
if mode == Mode.JEP:
124+
assert isinstance(pdoubles, list)
125+
assert isinstance(pdoubles[0][0], float)
126+
assert len(pdoubles) == 5
127+
assert len(pdoubles[0]) == 3
128+
elif mode == Mode.JPYPE:
129+
assert isinstance(pdoubles, np.ndarray)
130+
assert np.float64 == pdoubles.dtype
131+
assert (5, 3) == pdoubles.shape
125132
for i in range(len(nums_delta)):
126133
for j in range(len(nums_delta[i])):
127134
assert nums_delta[i][j] == pdoubles[i][j]

0 commit comments

Comments
 (0)