Skip to content

Commit db682a6

Browse files
author
James William Pye
committed
Implement Array.get_element.
1 parent fdf9a3f commit db682a6

2 files changed

Lines changed: 108 additions & 8 deletions

File tree

postgresql/test/test_types.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,50 @@ def testNegatives(self):
486486
self.failUnlessEqual(a[-2], 1)
487487
self.failUnlessEqual(a[-1], 2)
488488

489+
def testGetElement(self):
490+
a = Array([1,2,3,4])
491+
self.failUnlessEqual(a.get_element((0,)), 1)
492+
self.failUnlessEqual(a.get_element((1,)), 2)
493+
self.failUnlessEqual(a.get_element((2,)), 3)
494+
self.failUnlessEqual(a.get_element((3,)), 4)
495+
self.failUnlessEqual(a.get_element((-1,)), 4)
496+
self.failUnlessEqual(a.get_element((-2,)), 3)
497+
self.failUnlessEqual(a.get_element((-3,)), 2)
498+
self.failUnlessEqual(a.get_element((-4,)), 1)
499+
self.failUnlessRaises(IndexError, a.get_element, (4,))
500+
a = Array([[1,2],[3,4]])
501+
self.failUnlessEqual(a.get_element((0,0)), 1)
502+
self.failUnlessEqual(a.get_element((0,1,)), 2)
503+
self.failUnlessEqual(a.get_element((1,0,)), 3)
504+
self.failUnlessEqual(a.get_element((1,1,)), 4)
505+
self.failUnlessEqual(a.get_element((-1,-1)), 4)
506+
self.failUnlessEqual(a.get_element((-1,-2,)), 3)
507+
self.failUnlessEqual(a.get_element((-2,-1,)), 2)
508+
self.failUnlessEqual(a.get_element((-2,-2,)), 1)
509+
self.failUnlessRaises(IndexError, a.get_element, (2,0))
510+
self.failUnlessRaises(IndexError, a.get_element, (1,2))
511+
self.failUnlessRaises(IndexError, a.get_element, (0,2))
512+
513+
def testSQLGetElement(self):
514+
a = Array([1,2,3,4])
515+
self.failUnlessEqual(a.sql_get_element((1,)), 1)
516+
self.failUnlessEqual(a.sql_get_element((2,)), 2)
517+
self.failUnlessEqual(a.sql_get_element((3,)), 3)
518+
self.failUnlessEqual(a.sql_get_element((4,)), 4)
519+
self.failUnlessEqual(a.sql_get_element((0,)), None)
520+
self.failUnlessEqual(a.sql_get_element((5,)), None)
521+
self.failUnlessEqual(a.sql_get_element((-1,)), None)
522+
self.failUnlessEqual(a.sql_get_element((-2,)), None)
523+
self.failUnlessEqual(a.sql_get_element((-3,)), None)
524+
self.failUnlessEqual(a.sql_get_element((-4,)), None)
525+
a = Array([[1,2],[3,4]])
526+
self.failUnlessEqual(a.sql_get_element((1,1)), 1)
527+
self.failUnlessEqual(a.sql_get_element((1,2,)), 2)
528+
self.failUnlessEqual(a.sql_get_element((2,1,)), 3)
529+
self.failUnlessEqual(a.sql_get_element((2,2,)), 4)
530+
self.failUnlessEqual(a.sql_get_element((3,1)), None)
531+
self.failUnlessEqual(a.sql_get_element((1,3)), None)
532+
489533
if __name__ == '__main__':
490534
from types import ModuleType
491535
this = ModuleType("this")

postgresql/types/__init__.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ class Array(object):
239239
240240
There is also a `dimensions` property, but it is derived from the
241241
`lowerbounds` and `upperbounds` to yield a normalized description of the
242-
structure.
242+
ARRAY's structure.
243243
244244
The Python interfaces, such as __getitem__, are *not* subjected to the
245245
semantics of the lower and upper bounds. Rather, the normalized dimensions
@@ -405,15 +405,71 @@ def elements(self):
405405
def nest(self, seqtype = list):
406406
"""
407407
Transform the array into a nested list.
408+
409+
The `seqtype` keyword can be used to override the type used to represent
410+
the elements of a given axis.
408411
"""
409-
rl = []
410-
typ = self.__class__
411-
for x in self:
412-
if x.__class__ is typ:
413-
rl.append(x.nest())
412+
if self.ndims < 2:
413+
return seqtype(self._elements)
414+
else:
415+
rl = []
416+
for x in self:
417+
rl.append(x.nest(seqtype = seqtype))
418+
return seqtype(rl)
419+
420+
def get_element(self, address,
421+
idxerr = "index {0} at axis {1} is out of range {2}".format
422+
):
423+
"""
424+
Get an element in the array using the given axis sequence.
425+
426+
>>> a=Array([[1,2],[3,4]])
427+
>>> a.get_element((0,0)) == 1
428+
True
429+
>>> a.get_element((1,1)) == 4
430+
True
431+
432+
This is similar to getting items in a nested list::
433+
434+
>>> l=[[1,2],[3,4]]
435+
>>> l[0][0] == 1
436+
True
437+
"""
438+
if not self.dimensions:
439+
raise IndexError("array is empty")
440+
if len(address) != len(self.dimensions):
441+
raise ValueError("given axis sequence is inconsistent with number of dimensions")
442+
443+
# normalize axis specification (-N + DIM), check for IndexErrors, and
444+
# resolve the element's position.
445+
cur = 0
446+
nelements = len(self._elements)
447+
for n, a, dim in zip(range(len(address)), address, self.dimensions):
448+
if a < 0:
449+
a = a + dim
450+
if a < 0:
451+
raise IndexError(idxerr(a, n, dim))
414452
else:
415-
rl.append(x)
416-
return seqtype(rl)
453+
if a >= dim:
454+
raise IndexError(idxerr(a, n, dim))
455+
nelements = nelements // dim
456+
cur += (a * nelements)
457+
return self._elements[cur]
458+
459+
def sql_get_element(self, address):
460+
"""
461+
Like `get_element`, but with SQL indirection semantics. Notably, returns
462+
`None` on IndexError.
463+
"""
464+
try:
465+
a = [a - lb for (a, lb) in zip(address, self.lowerbounds)]
466+
# get_element accepts negatives, so check the converted sequence.
467+
for x in a:
468+
if x < 0:
469+
return None
470+
return self.get_element(a)
471+
except IndexError:
472+
return None
417473

418474
def __repr__(self):
419475
return '%s.%s(%r)' %(

0 commit comments

Comments
 (0)