Skip to content

Commit bf38bee

Browse files
author
raymond.hettinger
committed
Promote compress() from a recipe to being a regular itertool.
git-svn-id: http://svn.python.org/projects/python/trunk@68941 6015fed2-1504-0410-9fe1-9d1591cc4771
1 parent a475bd1 commit bf38bee

5 files changed

Lines changed: 206 additions & 13 deletions

File tree

Doc/library/collections.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ counts less than one::
298298
Section 4.6.3, Exercise 19*\.
299299

300300
* To enumerate all distinct multisets of a given size over a given set of
301-
elements, see the :func:`combinations_with_replacement` function in the
301+
elements, see :func:`combinations_with_replacement` in the
302302
:ref:`itertools-recipes` for itertools::
303303

304304
map(Counter, combinations_with_replacement('ABC', 2)) --> AA AB AC BB BC CC

Doc/library/itertools.rst

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,20 @@ loops that truncate the stream.
139139

140140
.. versionadded:: 2.6
141141

142+
.. function:: compress(data, selectors)
143+
144+
Make an iterator that filters elements from *data* returning only those that
145+
have a corresponding element in *selectors* that evaluates to ``True``.
146+
Stops when either the *data* or *selectors* iterables have been exhausted.
147+
Equivalent to::
148+
149+
def compress(data, selectors):
150+
# compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F
151+
return (d for d, s in izip(data, selectors) if s)
152+
153+
.. versionadded:: 2.7
154+
155+
142156
.. function:: count([n])
143157

144158
Make an iterator that returns consecutive integers starting with *n*. If not
@@ -679,10 +693,6 @@ which incur interpreter overhead.
679693
for n in xrange(2**len(pairs)):
680694
yield set(x for m, x in pairs if m&n)
681695

682-
def compress(data, selectors):
683-
"compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F"
684-
return (d for d, s in izip(data, selectors) if s)
685-
686696
def combinations_with_replacement(iterable, r):
687697
"combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
688698
# number items returned: (n+r-1)! / r! / (n-1)!

Lib/test/test_itertools.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,21 @@ def permutations2(iterable, r=None):
191191
self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
192192
self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
193193

194+
def test_compress(self):
195+
self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
196+
self.assertEqual(list(compress('ABCDEF', [0,0,0,0,0,0])), list(''))
197+
self.assertEqual(list(compress('ABCDEF', [1,1,1,1,1,1])), list('ABCDEF'))
198+
self.assertEqual(list(compress('ABCDEF', [1,0,1])), list('AC'))
199+
self.assertEqual(list(compress('ABC', [0,1,1,1,1,1])), list('BC'))
200+
n = 10000
201+
data = chain.from_iterable(repeat(range(6), n))
202+
selectors = chain.from_iterable(repeat((0, 1)))
203+
self.assertEqual(list(compress(data, selectors)), [1,3,5] * n)
204+
self.assertRaises(TypeError, compress, None, range(6)) # 1st arg not iterable
205+
self.assertRaises(TypeError, compress, range(6), None) # 2nd arg not iterable
206+
self.assertRaises(TypeError, compress, range(6)) # too few args
207+
self.assertRaises(TypeError, compress, range(6), None) # too many args
208+
194209
def test_count(self):
195210
self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
196211
self.assertEqual(zip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)])
@@ -701,6 +716,9 @@ def test_combinations(self):
701716
self.assertEqual(list(combinations(range(4), 3)),
702717
[(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
703718

719+
def test_compress(self):
720+
self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
721+
704722
def test_count(self):
705723
self.assertEqual(list(islice(count(10), 5)), [10, 11, 12, 13, 14])
706724

@@ -781,6 +799,10 @@ def test_combinations(self):
781799
a = []
782800
self.makecycle(combinations([1,2,a,3], 3), a)
783801

802+
def test_compress(self):
803+
a = []
804+
self.makecycle(compress('ABCDEF', [1,0,1,0,1,0]), a)
805+
784806
def test_cycle(self):
785807
a = []
786808
self.makecycle(cycle([a]*2), a)
@@ -934,6 +956,15 @@ def test_chain(self):
934956
self.assertRaises(TypeError, list, chain(N(s)))
935957
self.assertRaises(ZeroDivisionError, list, chain(E(s)))
936958

959+
def test_compress(self):
960+
for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
961+
n = len(s)
962+
for g in (G, I, Ig, S, L, R):
963+
self.assertEqual(list(compress(g(s), repeat(1))), list(g(s)))
964+
self.assertRaises(TypeError, compress, X(s), repeat(1))
965+
self.assertRaises(TypeError, list, compress(N(s), repeat(1)))
966+
self.assertRaises(ZeroDivisionError, list, compress(E(s), repeat(1)))
967+
937968
def test_product(self):
938969
for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
939970
self.assertRaises(TypeError, product, X(s))
@@ -1125,7 +1156,7 @@ class SubclassWithKwargsTest(unittest.TestCase):
11251156
def test_keywords_in_subclass(self):
11261157
# count is not subclassable...
11271158
for cls in (repeat, izip, ifilter, ifilterfalse, chain, imap,
1128-
starmap, islice, takewhile, dropwhile, cycle):
1159+
starmap, islice, takewhile, dropwhile, cycle, compress):
11291160
class Subclass(cls):
11301161
def __init__(self, newarg=None, *args):
11311162
cls.__init__(self, *args)
@@ -1262,10 +1293,6 @@ def __init__(self, newarg=None, *args):
12621293
... for n in xrange(2**len(pairs)):
12631294
... yield set(x for m, x in pairs if m&n)
12641295
1265-
>>> def compress(data, selectors):
1266-
... "compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F"
1267-
... return (d for d, s in izip(data, selectors) if s)
1268-
12691296
>>> def combinations_with_replacement(iterable, r):
12701297
... "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC"
12711298
... pool = tuple(iterable)
@@ -1361,9 +1388,6 @@ def __init__(self, newarg=None, *args):
13611388
>>> map(sorted, powerset('ab'))
13621389
[[], ['a'], ['b'], ['a', 'b']]
13631390
1364-
>>> list(compress('abcdef', [1,0,1,0,1,1]))
1365-
['a', 'c', 'e', 'f']
1366-
13671391
>>> list(combinations_with_replacement('abc', 2))
13681392
[('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
13691393

Misc/NEWS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ Library
147147

148148
- Issue #4863: distutils.mwerkscompiler has been removed.
149149

150+
- Added a new function: itertools.compress().
151+
150152
- Fix and properly document the multiprocessing module's logging
151153
support, expose the internal levels and provide proper usage
152154
examples.

Modules/itertoolsmodule.c

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2506,6 +2506,162 @@ static PyTypeObject permutations_type = {
25062506
};
25072507

25082508

2509+
/* compress object ************************************************************/
2510+
2511+
/* Equivalent to:
2512+
2513+
def compress(data, selectors):
2514+
"compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F"
2515+
return (d for d, s in izip(data, selectors) if s)
2516+
*/
2517+
2518+
typedef struct {
2519+
PyObject_HEAD
2520+
PyObject *data;
2521+
PyObject *selectors;
2522+
} compressobject;
2523+
2524+
static PyTypeObject compress_type;
2525+
2526+
static PyObject *
2527+
compress_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
2528+
{
2529+
PyObject *seq1, *seq2;
2530+
PyObject *data=NULL, *selectors=NULL;
2531+
compressobject *lz;
2532+
2533+
if (type == &compress_type && !_PyArg_NoKeywords("compress()", kwds))
2534+
return NULL;
2535+
2536+
if (!PyArg_UnpackTuple(args, "compress", 2, 2, &seq1, &seq2))
2537+
return NULL;
2538+
2539+
data = PyObject_GetIter(seq1);
2540+
if (data == NULL)
2541+
goto fail;
2542+
selectors = PyObject_GetIter(seq2);
2543+
if (selectors == NULL)
2544+
goto fail;
2545+
2546+
/* create compressobject structure */
2547+
lz = (compressobject *)type->tp_alloc(type, 0);
2548+
if (lz == NULL)
2549+
goto fail;
2550+
lz->data = data;
2551+
lz->selectors = selectors;
2552+
return (PyObject *)lz;
2553+
2554+
fail:
2555+
Py_XDECREF(data);
2556+
Py_XDECREF(selectors);
2557+
return NULL;
2558+
}
2559+
2560+
static void
2561+
compress_dealloc(compressobject *lz)
2562+
{
2563+
PyObject_GC_UnTrack(lz);
2564+
Py_XDECREF(lz->data);
2565+
Py_XDECREF(lz->selectors);
2566+
Py_TYPE(lz)->tp_free(lz);
2567+
}
2568+
2569+
static int
2570+
compress_traverse(compressobject *lz, visitproc visit, void *arg)
2571+
{
2572+
Py_VISIT(lz->data);
2573+
Py_VISIT(lz->selectors);
2574+
return 0;
2575+
}
2576+
2577+
static PyObject *
2578+
compress_next(compressobject *lz)
2579+
{
2580+
PyObject *data = lz->data, *selectors = lz->selectors;
2581+
PyObject *datum, *selector;
2582+
PyObject *(*datanext)(PyObject *) = *Py_TYPE(data)->tp_iternext;
2583+
PyObject *(*selectornext)(PyObject *) = *Py_TYPE(selectors)->tp_iternext;
2584+
int ok;
2585+
2586+
while (1) {
2587+
/* Steps: get datum, get selector, evaluate selector.
2588+
Order is important (to match the pure python version
2589+
in terms of which input gets a chance to raise an
2590+
exception first).
2591+
*/
2592+
2593+
datum = datanext(data);
2594+
if (datum == NULL)
2595+
return NULL;
2596+
2597+
selector = selectornext(selectors);
2598+
if (selector == NULL) {
2599+
Py_DECREF(datum);
2600+
return NULL;
2601+
}
2602+
2603+
ok = PyObject_IsTrue(selector);
2604+
Py_DECREF(selector);
2605+
if (ok == 1)
2606+
return datum;
2607+
Py_DECREF(datum);
2608+
if (ok == -1)
2609+
return NULL;
2610+
}
2611+
}
2612+
2613+
PyDoc_STRVAR(compress_doc,
2614+
"compress(data sequence, selector sequence) --> iterator over selected data\n\
2615+
\n\
2616+
Return data elements corresponding to true selector elements.\n\
2617+
Forms a shorter iterator from selected data elements using the\n\
2618+
selectors to choose the data elements.");
2619+
2620+
static PyTypeObject compress_type = {
2621+
PyVarObject_HEAD_INIT(NULL, 0)
2622+
"itertools.compress", /* tp_name */
2623+
sizeof(compressobject), /* tp_basicsize */
2624+
0, /* tp_itemsize */
2625+
/* methods */
2626+
(destructor)compress_dealloc, /* tp_dealloc */
2627+
0, /* tp_print */
2628+
0, /* tp_getattr */
2629+
0, /* tp_setattr */
2630+
0, /* tp_compare */
2631+
0, /* tp_repr */
2632+
0, /* tp_as_number */
2633+
0, /* tp_as_sequence */
2634+
0, /* tp_as_mapping */
2635+
0, /* tp_hash */
2636+
0, /* tp_call */
2637+
0, /* tp_str */
2638+
PyObject_GenericGetAttr, /* tp_getattro */
2639+
0, /* tp_setattro */
2640+
0, /* tp_as_buffer */
2641+
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
2642+
Py_TPFLAGS_BASETYPE, /* tp_flags */
2643+
compress_doc, /* tp_doc */
2644+
(traverseproc)compress_traverse, /* tp_traverse */
2645+
0, /* tp_clear */
2646+
0, /* tp_richcompare */
2647+
0, /* tp_weaklistoffset */
2648+
PyObject_SelfIter, /* tp_iter */
2649+
(iternextfunc)compress_next, /* tp_iternext */
2650+
0, /* tp_methods */
2651+
0, /* tp_members */
2652+
0, /* tp_getset */
2653+
0, /* tp_base */
2654+
0, /* tp_dict */
2655+
0, /* tp_descr_get */
2656+
0, /* tp_descr_set */
2657+
0, /* tp_dictoffset */
2658+
0, /* tp_init */
2659+
0, /* tp_alloc */
2660+
compress_new, /* tp_new */
2661+
PyObject_GC_Del, /* tp_free */
2662+
};
2663+
2664+
25092665
/* ifilter object ************************************************************/
25102666

25112667
typedef struct {
@@ -3552,6 +3708,7 @@ inititertools(void)
35523708
&starmap_type,
35533709
&imap_type,
35543710
&chain_type,
3711+
&compress_type,
35553712
&ifilter_type,
35563713
&ifilterfalse_type,
35573714
&count_type,

0 commit comments

Comments
 (0)