@@ -200,19 +200,62 @@ def test_math_functions():
200200
201201
202202def test_array_functions ():
203- data = [[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 ], [6.0 ]]
203+ data = [[1.0 , 2.0 , 3.0 , 3.0 ], [4.0 , 5.0 , 3 .0 ], [6.0 ]]
204204 ctx = SessionContext ()
205205 batch = pa .RecordBatch .from_arrays (
206206 [np .array (data , dtype = object )], names = ["arr" ]
207207 )
208208 df = ctx .create_dataframe ([[batch ]])
209209
def py_indexof(arr, v):
    """Return the 1-based position of the first ``v`` in ``arr``.

    Yields ``np.nan`` when ``arr.index`` reports that ``v`` is absent.
    """
    try:
        zero_based = arr.index(v)
    except ValueError:
        # Value not present: signal "no position" with NaN.
        return np.nan
    # Shift the Python 0-based index to a 1-based position.
    return zero_based + 1
215+
def py_arr_remove(arr, v, n=None):
    """Return a copy of ``arr`` with leading occurrences of ``v`` removed.

    At most ``n`` occurrences are dropped, first match first. When ``n`` is
    ``None`` the removal counter never equals it, so every occurrence is
    removed. The input sequence itself is left untouched.
    """
    remaining = arr[:]
    removed = 0
    # Stop once the quota is met or no occurrence is left.
    while removed != n and v in remaining:
        remaining.remove(v)
        removed += 1
    return remaining
227+
def py_arr_replace(arr, from_, to, n=None):
    """Return a copy of ``arr`` with occurrences of ``from_`` replaced by ``to``.

    Occurrences are replaced from the front. At most ``n`` are replaced; when
    ``n`` is ``None`` the replacement counter never equals it, so every
    occurrence is replaced. The input sequence itself is left untouched.
    """
    new_arr = arr[:]
    found = 0
    start = 0
    while found != n:
        try:
            # Resume after the previous hit so an already-replaced slot is
            # never matched again; otherwise ``to == from_`` would loop
            # forever when no finite ``n`` bounds the loop.
            idx = new_arr.index(from_, start)
        except ValueError:
            break
        new_arr[idx] = to
        found += 1
        start = idx + 1

    return new_arr
240+
210241 col = column ("arr" )
211242 test_items = [
212243 [
213244 f .array_append (col , literal (99.0 )),
214245 lambda : [np .append (arr , 99.0 ) for arr in data ],
215246 ],
247+ [
248+ f .array_push_back (col , literal (99.0 )),
249+ lambda : [np .append (arr , 99.0 ) for arr in data ],
250+ ],
251+ [
252+ f .list_append (col , literal (99.0 )),
253+ lambda : [np .append (arr , 99.0 ) for arr in data ],
254+ ],
255+ [
256+ f .list_push_back (col , literal (99.0 )),
257+ lambda : [np .append (arr , 99.0 ) for arr in data ],
258+ ],
216259 [
217260 f .array_concat (col , col ),
218261 lambda : [np .concatenate ([arr , arr ]) for arr in data ],
@@ -253,12 +296,174 @@ def test_array_functions():
253296 f .list_length (col ),
254297 lambda : [len (r ) for r in data ],
255298 ],
299+ [
300+ f .array_has (col , literal (1.0 )),
301+ lambda : [1.0 in r for r in data ],
302+ ],
303+ [
304+ f .array_has_all (
305+ col , f .make_array (* [literal (v ) for v in [1.0 , 3.0 , 5.0 ]])
306+ ),
307+ lambda : [np .all ([v in r for v in [1.0 , 3.0 , 5.0 ]]) for r in data ],
308+ ],
309+ [
310+ f .array_has_any (
311+ col , f .make_array (* [literal (v ) for v in [1.0 , 3.0 , 5.0 ]])
312+ ),
313+ lambda : [np .any ([v in r for v in [1.0 , 3.0 , 5.0 ]]) for r in data ],
314+ ],
315+ [
316+ f .array_position (col , literal (1.0 )),
317+ lambda : [py_indexof (r , 1.0 ) for r in data ],
318+ ],
319+ [
320+ f .array_indexof (col , literal (1.0 )),
321+ lambda : [py_indexof (r , 1.0 ) for r in data ],
322+ ],
323+ [
324+ f .list_position (col , literal (1.0 )),
325+ lambda : [py_indexof (r , 1.0 ) for r in data ],
326+ ],
327+ [
328+ f .list_indexof (col , literal (1.0 )),
329+ lambda : [py_indexof (r , 1.0 ) for r in data ],
330+ ],
331+ [
332+ f .array_positions (col , literal (1.0 )),
333+ lambda : [
334+ [i + 1 for i , _v in enumerate (r ) if _v == 1.0 ] for r in data
335+ ],
336+ ],
337+ [
338+ f .list_positions (col , literal (1.0 )),
339+ lambda : [
340+ [i + 1 for i , _v in enumerate (r ) if _v == 1.0 ] for r in data
341+ ],
342+ ],
343+ [
344+ f .array_ndims (col ),
345+ lambda : [np .array (r ).ndim for r in data ],
346+ ],
347+ [
348+ f .list_ndims (col ),
349+ lambda : [np .array (r ).ndim for r in data ],
350+ ],
351+ [
352+ f .array_prepend (literal (99.0 ), col ),
353+ lambda : [np .insert (arr , 0 , 99.0 ) for arr in data ],
354+ ],
355+ [
356+ f .array_push_front (literal (99.0 ), col ),
357+ lambda : [np .insert (arr , 0 , 99.0 ) for arr in data ],
358+ ],
359+ [
360+ f .list_prepend (literal (99.0 ), col ),
361+ lambda : [np .insert (arr , 0 , 99.0 ) for arr in data ],
362+ ],
363+ [
364+ f .list_push_front (literal (99.0 ), col ),
365+ lambda : [np .insert (arr , 0 , 99.0 ) for arr in data ],
366+ ],
367+ [
368+ f .array_pop_back (col ),
369+ lambda : [arr [:- 1 ] for arr in data ],
370+ ],
371+ [
372+ f .array_pop_front (col ),
373+ lambda : [arr [1 :] for arr in data ],
374+ ],
375+ [
376+ f .array_remove (col , literal (3.0 )),
377+ lambda : [py_arr_remove (arr , 3.0 , 1 ) for arr in data ],
378+ ],
379+ [
380+ f .list_remove (col , literal (3.0 )),
381+ lambda : [py_arr_remove (arr , 3.0 , 1 ) for arr in data ],
382+ ],
383+ [
384+ f .array_remove_n (col , literal (3.0 ), literal (2 )),
385+ lambda : [py_arr_remove (arr , 3.0 , 2 ) for arr in data ],
386+ ],
387+ [
388+ f .list_remove_n (col , literal (3.0 ), literal (2 )),
389+ lambda : [py_arr_remove (arr , 3.0 , 2 ) for arr in data ],
390+ ],
391+ [
392+ f .array_remove_all (col , literal (3.0 )),
393+ lambda : [py_arr_remove (arr , 3.0 ) for arr in data ],
394+ ],
395+ [
396+ f .list_remove_all (col , literal (3.0 )),
397+ lambda : [py_arr_remove (arr , 3.0 ) for arr in data ],
398+ ],
399+ [
400+ f .array_repeat (col , literal (2 )),
401+ lambda : [[arr ] * 2 for arr in data ],
402+ ],
403+ [
404+ f .array_replace (col , literal (3.0 ), literal (4.0 )),
405+ lambda : [py_arr_replace (arr , 3.0 , 4.0 , 1 ) for arr in data ],
406+ ],
407+ [
408+ f .list_replace (col , literal (3.0 ), literal (4.0 )),
409+ lambda : [py_arr_replace (arr , 3.0 , 4.0 , 1 ) for arr in data ],
410+ ],
411+ [
412+ f .array_replace_n (col , literal (3.0 ), literal (4.0 ), literal (1 )),
413+ lambda : [py_arr_replace (arr , 3.0 , 4.0 , 1 ) for arr in data ],
414+ ],
415+ [
416+ f .list_replace_n (col , literal (3.0 ), literal (4.0 ), literal (2 )),
417+ lambda : [py_arr_replace (arr , 3.0 , 4.0 , 2 ) for arr in data ],
418+ ],
419+ [
420+ f .array_replace_all (col , literal (3.0 ), literal (4.0 )),
421+ lambda : [py_arr_replace (arr , 3.0 , 4.0 ) for arr in data ],
422+ ],
423+ [
424+ f .list_replace_all (col , literal (3.0 ), literal (4.0 )),
425+ lambda : [py_arr_replace (arr , 3.0 , 4.0 ) for arr in data ],
426+ ],
427+ [
428+ f .array_slice (col , literal (2 ), literal (4 )),
429+ lambda : [arr [1 :4 ] for arr in data ],
430+ ],
431+ [
432+ f .list_slice (col , literal (- 1 ), literal (2 )),
433+ lambda : [arr [- 1 :2 ] for arr in data ],
434+ ],
256435 ]
257436
258437 for stmt , py_expr in test_items :
259- query_result = df .select (stmt ).collect ()[0 ].column (0 ).tolist ()
438+ query_result = df .select (stmt ).collect ()[0 ].column (0 )
439+ for a , b in zip (query_result , py_expr ()):
440+ np .testing .assert_array_almost_equal (
441+ np .array (a .as_py (), dtype = float ), np .array (b , dtype = float )
442+ )
443+
444+ obj_test_items = [
445+ [
446+ f .array_to_string (col , literal ("," )),
447+ lambda : ["," .join ([str (int (v )) for v in r ]) for r in data ],
448+ ],
449+ [
450+ f .array_join (col , literal ("," )),
451+ lambda : ["," .join ([str (int (v )) for v in r ]) for r in data ],
452+ ],
453+ [
454+ f .list_to_string (col , literal ("," )),
455+ lambda : ["," .join ([str (int (v )) for v in r ]) for r in data ],
456+ ],
457+ [
458+ f .list_join (col , literal ("," )),
459+ lambda : ["," .join ([str (int (v )) for v in r ]) for r in data ],
460+ ],
461+ ]
462+
463+ for stmt , py_expr in obj_test_items :
464+ query_result = np .array (df .select (stmt ).collect ()[0 ].column (0 ))
260465 for a , b in zip (query_result , py_expr ()):
261- np . testing . assert_array_almost_equal ( a , b )
466+ assert a == b
262467
263468
264469def test_string_functions (df ):
0 commit comments