1010#include "pycore_modsupport.h" // _PyArg_NoKwnames()
1111#include "pycore_object.h" // _PyObject_GC_TRACK(), _PyDebugAllocatorStats()
1212#include "pycore_tuple.h" // _PyTuple_FromArray()
13+ #include "pycore_setobject.h" // _PySet_NextEntry()
1314#include <stddef.h>
1415
1516/*[clinic input]
@@ -994,35 +995,36 @@ PyList_SetSlice(PyObject *a, Py_ssize_t ilow, Py_ssize_t ihigh, PyObject *v)
994995 return list_ass_slice ((PyListObject * )a , ilow , ihigh , v );
995996}
996997
997- static PyObject *
998+ static int
998999list_inplace_repeat_lock_held (PyListObject * self , Py_ssize_t n )
9991000{
10001001 Py_ssize_t input_size = PyList_GET_SIZE (self );
10011002 if (input_size == 0 || n == 1 ) {
1002- return Py_NewRef ( self ) ;
1003+ return 0 ;
10031004 }
10041005
10051006 if (n < 1 ) {
10061007 list_clear (self );
1007- return Py_NewRef ( self ) ;
1008+ return 0 ;
10081009 }
10091010
10101011 if (input_size > PY_SSIZE_T_MAX / n ) {
1011- return PyErr_NoMemory ();
1012+ PyErr_NoMemory ();
1013+ return -1 ;
10121014 }
10131015 Py_ssize_t output_size = input_size * n ;
10141016
1015- if (list_resize (self , output_size ) < 0 )
1016- return NULL ;
1017+ if (list_resize (self , output_size ) < 0 ) {
1018+ return -1 ;
1019+ }
10171020
10181021 PyObject * * items = self -> ob_item ;
10191022 for (Py_ssize_t j = 0 ; j < input_size ; j ++ ) {
10201023 _Py_RefcntAdd (items [j ], n - 1 );
10211024 }
10221025 _Py_memory_repeat ((char * )items , sizeof (PyObject * )* output_size ,
10231026 sizeof (PyObject * )* input_size );
1024-
1025- return Py_NewRef (self );
1027+ return 0 ;
10261028}
10271029
10281030static PyObject *
@@ -1031,7 +1033,12 @@ list_inplace_repeat(PyObject *_self, Py_ssize_t n)
10311033 PyObject * ret ;
10321034 PyListObject * self = (PyListObject * ) _self ;
10331035 Py_BEGIN_CRITICAL_SECTION (self );
1034- ret = list_inplace_repeat_lock_held (self , n );
1036+ if (list_inplace_repeat_lock_held (self , n ) < 0 ) {
1037+ ret = NULL ;
1038+ }
1039+ else {
1040+ ret = Py_NewRef (self );
1041+ }
10351042 Py_END_CRITICAL_SECTION ();
10361043 return ret ;
10371044}
@@ -1179,7 +1186,7 @@ list_extend_fast(PyListObject *self, PyObject *iterable)
11791186}
11801187
11811188static int
1182- list_extend_iter (PyListObject * self , PyObject * iterable )
1189+ list_extend_iter_lock_held (PyListObject * self , PyObject * iterable )
11831190{
11841191 PyObject * it = PyObject_GetIter (iterable );
11851192 if (it == NULL ) {
@@ -1253,45 +1260,78 @@ list_extend_iter(PyListObject *self, PyObject *iterable)
12531260 return -1 ;
12541261}
12551262
1256-
12571263static int
1258- list_extend (PyListObject * self , PyObject * iterable )
1264+ list_extend_lock_held (PyListObject * self , PyObject * iterable )
12591265{
1260- // Special cases:
1261- // 1) lists and tuples which can use PySequence_Fast ops
1262- // 2) extending self to self requires making a copy first
1263- if (PyList_CheckExact (iterable )
1264- || PyTuple_CheckExact (iterable )
1265- || (PyObject * )self == iterable )
1266- {
1267- iterable = PySequence_Fast (iterable , "argument must be iterable" );
1268- if (!iterable ) {
1269- return -1 ;
1270- }
1271-
1272- int res = list_extend_fast (self , iterable );
1273- Py_DECREF (iterable );
1274- return res ;
1275- }
1276- else {
1277- return list_extend_iter (self , iterable );
1266+ PyObject * seq = PySequence_Fast (iterable , "argument must be iterable" );
1267+ if (!seq ) {
1268+ return -1 ;
12781269 }
1279- }
12801270
1271+ int res = list_extend_fast (self , seq );
1272+ Py_DECREF (seq );
1273+ return res ;
1274+ }
12811275
1282- PyObject *
1283- _PyList_Extend (PyListObject * self , PyObject * iterable )
1276+ static int
1277+ list_extend_set (PyListObject * self , PySetObject * other )
12841278{
1285- if (list_extend (self , iterable ) < 0 ) {
1286- return NULL ;
1279+ Py_ssize_t m = Py_SIZE (self );
1280+ Py_ssize_t n = PySet_GET_SIZE (other );
1281+ if (list_resize (self , m + n ) < 0 ) {
1282+ return -1 ;
12871283 }
1288- Py_RETURN_NONE ;
1284+ /* populate the end of self with iterable's items */
1285+ Py_ssize_t setpos = 0 ;
1286+ Py_hash_t hash ;
1287+ PyObject * key ;
1288+ PyObject * * dest = self -> ob_item + m ;
1289+ while (_PySet_NextEntry ((PyObject * )other , & setpos , & key , & hash )) {
1290+ Py_INCREF (key );
1291+ * dest = key ;
1292+ dest ++ ;
1293+ }
1294+ Py_SET_SIZE (self , m + n );
1295+ return 0 ;
12891296}
12901297
1298+ static int
1299+ _list_extend (PyListObject * self , PyObject * iterable )
1300+ {
1301+ // Special case:
1302+ // lists and tuples which can use PySequence_Fast ops
1303+ // TODO(@corona10): Add more special cases for other types.
1304+ int res = -1 ;
1305+ if ((PyObject * )self == iterable ) {
1306+ Py_BEGIN_CRITICAL_SECTION (self );
1307+ res = list_inplace_repeat_lock_held (self , 2 );
1308+ Py_END_CRITICAL_SECTION ();
1309+ }
1310+ else if (PyList_CheckExact (iterable )) {
1311+ Py_BEGIN_CRITICAL_SECTION2 (self , iterable );
1312+ res = list_extend_lock_held (self , iterable );
1313+ Py_END_CRITICAL_SECTION2 ();
1314+ }
1315+ else if (PyTuple_CheckExact (iterable )) {
1316+ Py_BEGIN_CRITICAL_SECTION (self );
1317+ res = list_extend_lock_held (self , iterable );
1318+ Py_END_CRITICAL_SECTION ();
1319+ }
1320+ else if (PyAnySet_CheckExact (iterable )) {
1321+ Py_BEGIN_CRITICAL_SECTION2 (self , iterable );
1322+ res = list_extend_set (self , (PySetObject * )iterable );
1323+ Py_END_CRITICAL_SECTION2 ();
1324+ }
1325+ else {
1326+ Py_BEGIN_CRITICAL_SECTION (self );
1327+ res = list_extend_iter_lock_held (self , iterable );
1328+ Py_END_CRITICAL_SECTION ();
1329+ }
1330+ return res ;
1331+ }
12911332
12921333/*[clinic input]
1293- @critical_section self iterable
1294- list.extend as py_list_extend
1334+ list.extend as list_extend
12951335
12961336 iterable: object
12971337 /
@@ -1300,12 +1340,20 @@ Extend list by appending elements from the iterable.
13001340[clinic start generated code]*/
13011341
13021342static PyObject *
1303- py_list_extend_impl (PyListObject * self , PyObject * iterable )
1304- /*[clinic end generated code: output=a2f115ceace2c845 input=1d42175414e1a5f3 ]*/
1343+ list_extend (PyListObject * self , PyObject * iterable )
1344+ /*[clinic end generated code: output=630fb3bca0c8e789 input=979da7597a515791 ]*/
13051345{
1306- return _PyList_Extend (self , iterable );
1346+ if (_list_extend (self , iterable ) < 0 ) {
1347+ return NULL ;
1348+ }
1349+ Py_RETURN_NONE ;
13071350}
13081351
1352+ PyObject *
1353+ _PyList_Extend (PyListObject * self , PyObject * iterable )
1354+ {
1355+ return list_extend (self , iterable );
1356+ }
13091357
13101358int
13111359PyList_Extend (PyObject * self , PyObject * iterable )
@@ -1314,7 +1362,7 @@ PyList_Extend(PyObject *self, PyObject *iterable)
13141362 PyErr_BadInternalCall ();
13151363 return -1 ;
13161364 }
1317- return list_extend ((PyListObject * )self , iterable );
1365+ return _list_extend ((PyListObject * )self , iterable );
13181366}
13191367
13201368
@@ -1334,7 +1382,7 @@ static PyObject *
13341382list_inplace_concat (PyObject * _self , PyObject * other )
13351383{
13361384 PyListObject * self = (PyListObject * )_self ;
1337- if (list_extend (self , other ) < 0 ) {
1385+ if (_list_extend (self , other ) < 0 ) {
13381386 return NULL ;
13391387 }
13401388 return Py_NewRef (self );
@@ -3168,7 +3216,7 @@ list___init___impl(PyListObject *self, PyObject *iterable)
31683216 list_clear (self );
31693217 }
31703218 if (iterable != NULL ) {
3171- if (list_extend (self , iterable ) < 0 ) {
3219+ if (_list_extend (self , iterable ) < 0 ) {
31723220 return -1 ;
31733221 }
31743222 }
@@ -3229,7 +3277,7 @@ static PyMethodDef list_methods[] = {
32293277 LIST_COPY_METHODDEF
32303278 LIST_APPEND_METHODDEF
32313279 LIST_INSERT_METHODDEF
3232- PY_LIST_EXTEND_METHODDEF
3280+ LIST_EXTEND_METHODDEF
32333281 LIST_POP_METHODDEF
32343282 LIST_REMOVE_METHODDEF
32353283 LIST_INDEX_METHODDEF
0 commit comments