@@ -26,9 +26,9 @@ def parse_args():
2626 required = True ,
2727 )
2828 parser .add_argument (
29- "--pr " ,
30- type = int ,
31- help = "Number of the PR in the stack to check and create corresponding PR" ,
29+ "--ref " ,
30+ type = str ,
31+ help = "Ref fo PR in the stack to check and create corresponding PR" ,
3232 required = True ,
3333 )
3434 return parser .parse_args ()
@@ -68,12 +68,18 @@ def extract_stack_from_body(pr_body: str) -> List[int]:
6868 return list (reversed (prs ))
6969
7070
71- def get_pr_stack_from_number (pr_number : int , repo : Repository ) -> List [int ]:
71+ def get_pr_stack_from_number (ref : str , repo : Repository ) -> List [int ]:
72+ if ref .isnumeric ():
73+ pr_number = int (ref )
74+ else :
75+ branch_name = ref .replace ("refs/heads/" , "" )
76+ pr_number = repo .get_branch (branch_name ).commit .get_pulls ()[0 ].number
77+
7278 pr_stack = extract_stack_from_body (repo .get_pull (pr_number ).body )
7379
7480 if not pr_stack :
7581 raise Exception (
76- f"Could not find PR stack in body of # { pr_number } . "
82+ f"Could not find PR stack in body of ref . "
7783 + "Please make sure that the PR was created with ghstack."
7884 )
7985
@@ -129,7 +135,7 @@ def main():
129135
130136 with Github (auth = Auth .Token (os .environ ["GITHUB_TOKEN" ])) as gh :
131137 repo = gh .get_repo (args .repo )
132- create_prs_for_orig_branch (get_pr_stack_from_number (args .pr , repo ), repo )
138+ create_prs_for_orig_branch (get_pr_stack_from_number (args .ref , repo ), repo )
133139
134140
135141if __name__ == "__main__" :
0 commit comments