@@ -438,28 +438,28 @@ def raise_stmt(self, nodelist):
438438 return n
439439
440440 def import_stmt (self , nodelist ):
441- # import_stmt: 'import' dotted_as_name (',' dotted_as_name)* |
442- # from: 'from' dotted_name 'import'
443- # ('*' | import_as_name (',' import_as_name)*)
444- if nodelist [0 ][1 ] == 'from' :
445- names = []
446- if nodelist [3 ][0 ] == token .NAME :
447- for i in range (3 , len (nodelist ), 2 ):
448- names .append ((nodelist [i ][1 ], None ))
449- else :
450- for i in range (3 , len (nodelist ), 2 ):
451- names .append (self .com_import_as_name (nodelist [i ]))
452- n = From (self .com_dotted_name (nodelist [1 ]), names )
453- n .lineno = nodelist [0 ][2 ]
454- return n
441+ # import_stmt: import_name | import_from
442+ assert len (nodelist ) == 1
443+ return self .com_node (nodelist [0 ])
455444
456- if nodelist [1 ][0 ] == symbol .dotted_name :
457- names = [(self .com_dotted_name (nodelist [1 ][1 :]), None )]
445+ def import_name (self , nodelist ):
446+ # import_name: 'import' dotted_as_names
447+ n = Import (self .com_dotted_as_names (nodelist [1 ]))
448+ n .lineno = nodelist [0 ][2 ]
449+ return n
450+
451+ def import_from (self , nodelist ):
452+ # import_from: 'from' dotted_name 'import' ('*' |
453+ # '(' import_as_names ')' | import_as_names)
454+ assert nodelist [0 ][1 ] == 'from'
455+ assert nodelist [1 ][0 ] == symbol .dotted_name
456+ assert nodelist [2 ][1 ] == 'import'
457+ fromname = self .com_dotted_name (nodelist [1 ])
458+ if nodelist [3 ][0 ] == token .STAR :
459+ n = From (fromname , [('*' , None )])
458460 else :
459- names = []
460- for i in range (1 , len (nodelist ), 2 ):
461- names .append (self .com_dotted_as_name (nodelist [i ]))
462- n = Import (names )
461+ node = nodelist [3 + (nodelist [3 ][0 ] == token .LPAR )]
462+ n = From (fromname , self .com_import_as_names (node ))
463463 n .lineno = nodelist [0 ][2 ]
464464 return n
465465
@@ -895,29 +895,41 @@ def com_dotted_name(self, node):
895895 return name [:- 1 ]
896896
897897 def com_dotted_as_name (self , node ):
898- dot = self .com_dotted_name (node [1 ])
899- if len (node ) <= 2 :
898+ assert node [0 ] == symbol .dotted_as_name
899+ node = node [1 :]
900+ dot = self .com_dotted_name (node [0 ][1 :])
901+ if len (node ) == 1 :
900902 return dot , None
901- if node [0 ] == symbol .dotted_name :
902- pass
903- else :
904- assert node [2 ][1 ] == 'as'
905- assert node [3 ][0 ] == token .NAME
906- return dot , node [3 ][1 ]
903+ assert node [1 ][1 ] == 'as'
904+ assert node [2 ][0 ] == token .NAME
905+ return dot , node [2 ][1 ]
906+
907+ def com_dotted_as_names (self , node ):
908+ assert node [0 ] == symbol .dotted_as_names
909+ node = node [1 :]
910+ names = [self .com_dotted_as_name (node [0 ])]
911+ for i in range (2 , len (node ), 2 ):
912+ names .append (self .com_dotted_as_name (node [i ]))
913+ return names
907914
908915 def com_import_as_name (self , node ):
909- if node [0 ] == token .STAR :
910- return '*' , None
911916 assert node [0 ] == symbol .import_as_name
912917 node = node [1 :]
918+ assert node [0 ][0 ] == token .NAME
913919 if len (node ) == 1 :
914- assert node [0 ][0 ] == token .NAME
915920 return node [0 ][1 ], None
916-
917921 assert node [1 ][1 ] == 'as' , node
918922 assert node [2 ][0 ] == token .NAME
919923 return node [0 ][1 ], node [2 ][1 ]
920924
925+ def com_import_as_names (self , node ):
926+ assert node [0 ] == symbol .import_as_names
927+ node = node [1 :]
928+ names = [self .com_import_as_name (node [0 ])]
929+ for i in range (2 , len (node ), 2 ):
930+ names .append (self .com_import_as_name (node [i ]))
931+ return names
932+
921933 def com_bases (self , node ):
922934 bases = []
923935 for i in range (1 , len (node ), 2 ):
0 commit comments