|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Microsoft Corporation. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +# DeepSpeed Team |
| 6 | +""" |
| 7 | +USAGE: |
| 8 | +$ python3 script/replace_copyright.py --repo_dir ./ |
| 9 | +""" |
| 10 | + |
| 11 | +import os |
| 12 | +import argparse |
| 13 | + |
| 14 | +NEW_COPYRIGHT = ("Copyright (c) Microsoft Corporation.", "SPDX-License-Identifier: Apache-2.0", "", "DeepSpeed Team") |
| 15 | + |
| 16 | +PY_SL_COMMENT = "#" |
| 17 | +PY_ML_SINGLE = "'''" |
| 18 | +PY_ML_DOUBLE = '"""' |
| 19 | +PY_COMMENTS = (PY_SL_COMMENT, PY_ML_SINGLE, PY_ML_DOUBLE) |
| 20 | + |
| 21 | +C_SL_COMMENT = "//" |
| 22 | +C_ML_OPEN = "/*" |
| 23 | +C_ML_CLOSE = "*/" |
| 24 | +C_COMMENTS = (C_SL_COMMENT, C_ML_OPEN, C_ML_CLOSE) |
| 25 | + |
| 26 | +BASH_SL_COMMENT = "#" |
| 27 | +BASH_COMMENTS = (BASH_SL_COMMENT, ) |
| 28 | + |
| 29 | +DELIM = "|/-\|/-\|BARRIER|/-\|/-\|" # noqa: W605 |
| 30 | + |
| 31 | + |
| 32 | +def parser_args(): |
| 33 | + parser = argparse.ArgumentParser() |
| 34 | + parser.add_argument("--repo_dir", type=str, help="Repository directory") |
| 35 | + parser.add_argument("--python_style_ext", |
| 36 | + type=str, |
| 37 | + nargs="+", |
| 38 | + default=[".py"], |
| 39 | + help="File types to process with python-style comments") |
| 40 | + parser.add_argument("--bash_style_ext", |
| 41 | + type=str, |
| 42 | + nargs="+", |
| 43 | + default=[".sh"], |
| 44 | + help="File types to process with bash-style comments") |
| 45 | + parser.add_argument("--c_style_ext", |
| 46 | + type=str, |
| 47 | + nargs="+", |
| 48 | + default=[ |
| 49 | + ".c", |
| 50 | + ".cpp", |
| 51 | + ".cu", |
| 52 | + ".h", |
| 53 | + ".hpp", |
| 54 | + ".cuh", |
| 55 | + ".cc", |
| 56 | + ".hip", |
| 57 | + ".tr", |
| 58 | + ], |
| 59 | + help="File types to process with C-style comments") |
| 60 | + args = parser.parse_args() |
| 61 | + return args |
| 62 | + |
| 63 | + |
| 64 | +# These get_header_* functions are ugly, but they work :) |
| 65 | +def get_header_py(fp): |
| 66 | + with open(fp, "r") as f: |
| 67 | + lines = iter(l for l in f.readlines()) |
| 68 | + |
| 69 | + header = [] |
| 70 | + rest = [] |
| 71 | + in_multiline = False |
| 72 | + multiline_type = None |
| 73 | + |
| 74 | + while (l := next(lines, None)) is not None: |
| 75 | + l = l.strip() |
| 76 | + if l.startswith(PY_ML_SINGLE) or l.startswith(PY_ML_DOUBLE): |
| 77 | + # Detected multiline comment |
| 78 | + if in_multiline and multiline_type == l[:3]: |
| 79 | + # Ended a multiline comment |
| 80 | + in_multiline = False |
| 81 | + else: |
| 82 | + # Started a multiline comment |
| 83 | + in_multiline = True |
| 84 | + multiline_type = l[:3] |
| 85 | + if l.endswith(multiline_type) and len(l) >= 6: |
| 86 | + # Opened and closed multiline comment on single line |
| 87 | + in_multiline = False |
| 88 | + elif in_multiline and l.endswith(multiline_type): |
| 89 | + # Ended a multiline comment |
| 90 | + in_multiline = False |
| 91 | + elif not (in_multiline or l.startswith(PY_SL_COMMENT) or l == ""): |
| 92 | + # Not in a comment |
| 93 | + rest += [l + "\n"] |
| 94 | + break |
| 95 | + header.append(l) |
| 96 | + |
| 97 | + rest += list(lines) |
| 98 | + |
| 99 | + return header, rest |
| 100 | + |
| 101 | + |
| 102 | +def get_header_c(fp): |
| 103 | + with open(fp, "r") as f: |
| 104 | + lines = iter(l for l in f.readlines()) |
| 105 | + |
| 106 | + header = [] |
| 107 | + rest = [] |
| 108 | + in_multiline = False |
| 109 | + |
| 110 | + while (l := next(lines, None)) is not None: |
| 111 | + l = l.strip() |
| 112 | + if l.startswith(C_ML_OPEN): |
| 113 | + # Detected multiline comment |
| 114 | + if not l.endswith(C_ML_CLOSE): |
| 115 | + # multiline comment not closed on same line |
| 116 | + in_multiline = True |
| 117 | + elif l.endswith(C_ML_CLOSE): |
| 118 | + # Ended a multline comment |
| 119 | + in_multiline = False |
| 120 | + elif not in_multiline or l.startswith(C_SL_COMMENT) or l.isspace(): |
| 121 | + # Not in a comment |
| 122 | + rest += [l + "\n"] |
| 123 | + break |
| 124 | + header.append(l) |
| 125 | + |
| 126 | + rest += list(lines) |
| 127 | + |
| 128 | + return header, rest |
| 129 | + |
| 130 | + |
| 131 | +def get_header_bash(fp): |
| 132 | + with open(fp, "r") as f: |
| 133 | + lines = iter(l for l in f.readlines()) |
| 134 | + |
| 135 | + header = [] |
| 136 | + rest = [] |
| 137 | + |
| 138 | + while (l := next(lines, None)) is not None: |
| 139 | + l = l.strip() |
| 140 | + if not l.startswith(BASH_SL_COMMENT) or l.isspace(): |
| 141 | + # Not in a comment |
| 142 | + rest += [l + "\n"] |
| 143 | + break |
| 144 | + header.append(l) |
| 145 | + |
| 146 | + rest += list(lines) |
| 147 | + |
| 148 | + return header, rest |
| 149 | + |
| 150 | + |
| 151 | +def remove_comments(line, comment_strs): |
| 152 | + for cstr in comment_strs: |
| 153 | + line = line.replace(cstr, "") |
| 154 | + return line |
| 155 | + |
| 156 | + |
| 157 | +def format_multiline_comment(text, comment_type): |
| 158 | + if comment_type == PY_COMMENTS: |
| 159 | + text = f"\n{comment_type[2]}\n" + "\n".join(text) + f"{comment_type[2]}" |
| 160 | + if comment_type == C_COMMENTS: |
| 161 | + text = f"\n{comment_type[1]}\n" + "\n".join(text) + f"{comment_type[2]}" |
| 162 | + if comment_type == BASH_COMMENTS: |
| 163 | + text = "\n".join([f"{comment_type[0]}{l}" for l in text]) |
| 164 | + return text |
| 165 | + |
| 166 | + |
| 167 | +def modify_file_header(fp, file_header, rest_of_file, preserve_text_store, comment_type): |
| 168 | + header_text = "\n".join(file_header) |
| 169 | + if not (header_text.strip() == "" or header_text in preserve_text_store): |
| 170 | + # Unique header, need to get user input |
| 171 | + print("\n", DELIM, "\n") |
| 172 | + for idx, line in enumerate(file_header): |
| 173 | + print(f"{idx}: {line}") |
| 174 | + print("\n", DELIM, "\n") |
| 175 | + print("\nIndicate the FIRST line of the Header to KEEP") |
| 176 | + print("(shebang #! lines will be automatically processed and should not be included).") |
| 177 | + keep_idx = input("Enter number (or leave blank if no lines should be preserved): ") |
| 178 | + preserve_text_store[header_text] = file_header[int(keep_idx):] if keep_idx != "" else "" |
| 179 | + |
| 180 | + # Identify any shebang lines in the file |
| 181 | + shebang = "\n".join([l for l in file_header if l.startswith("#!")]) |
| 182 | + if shebang != "": |
| 183 | + shebang += "\n" |
| 184 | + |
| 185 | + # Get the text we should preserve in this file and process to remove comment characters |
| 186 | + text_to_preserve = preserve_text_store.get(header_text, [""]) |
| 187 | + text_to_preserve = [remove_comments(l, comment_type) for l in text_to_preserve] |
| 188 | + |
| 189 | + # Format the text we want to keep into a new multiline comment |
| 190 | + if "".join(text_to_preserve) == "": |
| 191 | + text_to_preserve = "" |
| 192 | + else: |
| 193 | + text_to_preserve = format_multiline_comment(text_to_preserve, comment_type) |
| 194 | + |
| 195 | + # Generate the copyright text we will be adding |
| 196 | + copyright_text = "\n".join([f"{comment_type[0]} {l}" if l != "" else l for l in NEW_COPYRIGHT]) |
| 197 | + |
| 198 | + # Assemble the new header |
| 199 | + new_header = shebang + copyright_text + text_to_preserve |
| 200 | + |
| 201 | + # Write out the new file |
| 202 | + new_file_contents = new_header + "\n" + "".join(rest_of_file) |
| 203 | + with open(fp, "w") as f: |
| 204 | + f.write(new_file_contents) |
| 205 | + |
| 206 | + return preserve_text_store # Return so we can reuse for future files |
| 207 | + |
| 208 | + |
| 209 | +def main(args): |
| 210 | + preserve_text_store = {} # Used to track header comments we should preserve |
| 211 | + for root, dirs, fnames in os.walk(args.repo_dir): |
| 212 | + # Walk across directory looking for all files with extensions we want to modify |
| 213 | + for ext in args.python_style_ext: |
| 214 | + fpaths = [os.path.join(root, fn) for fn in fnames if fn.endswith(ext)] |
| 215 | + for fp in fpaths: |
| 216 | + file_header, rest_of_file = get_header_py(fp) |
| 217 | + preserve_text_store = modify_file_header(fp, file_header, rest_of_file, preserve_text_store, |
| 218 | + PY_COMMENTS) |
| 219 | + for ext in args.c_style_ext: |
| 220 | + fpaths = [os.path.join(root, fn) for fn in fnames if fn.endswith(ext)] |
| 221 | + for fp in fpaths: |
| 222 | + file_header, rest_of_file = get_header_c(fp) |
| 223 | + preserve_text_store = modify_file_header(fp, file_header, rest_of_file, preserve_text_store, |
| 224 | + C_COMMENTS) |
| 225 | + for ext in args.bash_style_ext: |
| 226 | + fpaths = [os.path.join(root, fn) for fn in fnames if fn.endswith(ext)] |
| 227 | + for fp in fpaths: |
| 228 | + file_header, rest_of_file = get_header_bash(fp) |
| 229 | + preserve_text_store = modify_file_header(fp, file_header, rest_of_file, preserve_text_store, |
| 230 | + BASH_COMMENTS) |
| 231 | + |
| 232 | + |
| 233 | +if __name__ == "__main__": |
| 234 | + args = parser_args() |
| 235 | + main(args) |
0 commit comments