Skip to content

Commit 6bd94de

Browse files
authored
bpo-42567: [Enum] call __init_subclass__ after members are added (GH-23714)
When creating an Enum, type.__new__ calls __init_subclass__, but at that point the members have not been added. This patch suppresses the initial call, then manually calls the ancestor __init_subclass__ before returning the new Enum class.
1 parent 2a35137 commit 6bd94de

3 files changed

Lines changed: 104 additions & 2 deletions

File tree

Lib/enum.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
]
1010

1111

12+
class _NoInitSubclass:
13+
"""
14+
temporary base class to suppress calling __init_subclass__
15+
"""
16+
@classmethod
17+
def __init_subclass__(cls, **kwds):
18+
pass
19+
1220
def _is_descriptor(obj):
1321
"""
1422
Returns True if obj is a descriptor, False otherwise.
@@ -157,7 +165,7 @@ def __prepare__(metacls, cls, bases):
157165
)
158166
return enum_dict
159167

160-
def __new__(metacls, cls, bases, classdict):
168+
def __new__(metacls, cls, bases, classdict, **kwds):
161169
# an Enum class is final once enumeration items have been defined; it
162170
# cannot be mixed with other types (int, float, etc.) if it has an
163171
# inherited __new__ unless a new __new__ is defined (or the resulting
@@ -192,8 +200,22 @@ def __new__(metacls, cls, bases, classdict):
192200
if '__doc__' not in classdict:
193201
classdict['__doc__'] = 'An enumeration.'
194202

203+
# postpone calling __init_subclass__
204+
if '__init_subclass__' in classdict and classdict['__init_subclass__'] is None:
205+
raise TypeError('%s.__init_subclass__ cannot be None')
206+
# remove current __init_subclass__ so previous one can be found with getattr
207+
new_init_subclass = classdict.pop('__init_subclass__', None)
195208
# create our new Enum type
196-
enum_class = super().__new__(metacls, cls, bases, classdict)
209+
if bases:
210+
bases = (_NoInitSubclass, ) + bases
211+
enum_class = type.__new__(metacls, cls, bases, classdict)
212+
enum_class.__bases__ = enum_class.__bases__[1:] #or (object, )
213+
else:
214+
enum_class = type.__new__(metacls, cls, bases, classdict)
215+
old_init_subclass = getattr(enum_class, '__init_subclass__', None)
216+
# and restore the new one (if there was one)
217+
if new_init_subclass is not None:
218+
enum_class.__init_subclass__ = classmethod(new_init_subclass)
197219
enum_class._member_names_ = [] # names in definition order
198220
enum_class._member_map_ = {} # name->value map
199221
enum_class._member_type_ = member_type
@@ -305,6 +327,9 @@ def __new__(metacls, cls, bases, classdict):
305327
if _order_ != enum_class._member_names_:
306328
raise TypeError('member order does not match _order_')
307329

330+
# finally, call parents' __init_subclass__
331+
if Enum is not None and old_init_subclass is not None:
332+
old_init_subclass(**kwds)
308333
return enum_class
309334

310335
def __bool__(self):
@@ -682,6 +707,9 @@ def _generate_next_value_(name, start, count, last_values):
682707
else:
683708
return start
684709

710+
def __init_subclass__(cls, **kwds):
711+
super().__init_subclass__(**kwds)
712+
685713
@classmethod
686714
def _missing_(cls, value):
687715
return None

Lib/test/test_enum.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,6 +2117,43 @@ class ThirdFailedStrEnum(StrEnum):
21172117
class ThirdFailedStrEnum(StrEnum):
21182118
one = '1'
21192119
two = b'2', 'ascii', 9
2120+
2121+
def test_init_subclass(self):
2122+
class MyEnum(Enum):
2123+
def __init_subclass__(cls, **kwds):
2124+
super(MyEnum, cls).__init_subclass__(**kwds)
2125+
self.assertFalse(cls.__dict__.get('_test', False))
2126+
cls._test1 = 'MyEnum'
2127+
#
2128+
class TheirEnum(MyEnum):
2129+
def __init_subclass__(cls, **kwds):
2130+
super().__init_subclass__(**kwds)
2131+
cls._test2 = 'TheirEnum'
2132+
class WhoseEnum(TheirEnum):
2133+
def __init_subclass__(cls, **kwds):
2134+
pass
2135+
class NoEnum(WhoseEnum):
2136+
ONE = 1
2137+
self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum')
2138+
self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum')
2139+
self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum')
2140+
self.assertFalse(NoEnum.__dict__.get('_test1', False))
2141+
self.assertFalse(NoEnum.__dict__.get('_test2', False))
2142+
#
2143+
class OurEnum(MyEnum):
2144+
def __init_subclass__(cls, **kwds):
2145+
cls._test2 = 'OurEnum'
2146+
class WhereEnum(OurEnum):
2147+
def __init_subclass__(cls, **kwds):
2148+
pass
2149+
class NeverEnum(WhereEnum):
2150+
ONE = 'one'
2151+
self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum')
2152+
self.assertFalse(WhereEnum.__dict__.get('_test1', False))
2153+
self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum')
2154+
self.assertFalse(NeverEnum.__dict__.get('_test1', False))
2155+
self.assertFalse(NeverEnum.__dict__.get('_test2', False))
2156+
21202157

21212158
class TestOrder(unittest.TestCase):
21222159

@@ -2573,6 +2610,42 @@ def cycle_enum():
25732610
'at least one thread failed while creating composite members')
25742611
self.assertEqual(256, len(seen), 'too many composite members created')
25752612

2613+
def test_init_subclass(self):
2614+
class MyEnum(Flag):
2615+
def __init_subclass__(cls, **kwds):
2616+
super().__init_subclass__(**kwds)
2617+
self.assertFalse(cls.__dict__.get('_test', False))
2618+
cls._test1 = 'MyEnum'
2619+
#
2620+
class TheirEnum(MyEnum):
2621+
def __init_subclass__(cls, **kwds):
2622+
super(TheirEnum, cls).__init_subclass__(**kwds)
2623+
cls._test2 = 'TheirEnum'
2624+
class WhoseEnum(TheirEnum):
2625+
def __init_subclass__(cls, **kwds):
2626+
pass
2627+
class NoEnum(WhoseEnum):
2628+
ONE = 1
2629+
self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum')
2630+
self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum')
2631+
self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum')
2632+
self.assertFalse(NoEnum.__dict__.get('_test1', False))
2633+
self.assertFalse(NoEnum.__dict__.get('_test2', False))
2634+
#
2635+
class OurEnum(MyEnum):
2636+
def __init_subclass__(cls, **kwds):
2637+
cls._test2 = 'OurEnum'
2638+
class WhereEnum(OurEnum):
2639+
def __init_subclass__(cls, **kwds):
2640+
pass
2641+
class NeverEnum(WhereEnum):
2642+
ONE = 1
2643+
self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum')
2644+
self.assertFalse(WhereEnum.__dict__.get('_test1', False))
2645+
self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum')
2646+
self.assertFalse(NeverEnum.__dict__.get('_test1', False))
2647+
self.assertFalse(NeverEnum.__dict__.get('_test2', False))
2648+
25762649

25772650
class TestIntFlag(unittest.TestCase):
25782651
"""Tests of the IntFlags."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`Enum`: call `__init_subclass__` after members have been added

0 commit comments

Comments
 (0)