|
15 | 15 | # limitations under the License. |
16 | 16 | """Standalone script to pause/unpause all the dags in the specific environment.""" |
17 | 17 |
|
18 | | -from __future__ import annotations |
19 | | - |
20 | 18 | import argparse |
21 | | -import json |
22 | 19 | import logging |
23 | | -import re |
24 | | -import subprocess |
25 | | -import sys |
26 | 20 | from typing import Any |
27 | 21 |
|
| 22 | +import google.auth |
| 23 | +from google.auth.transport.requests import AuthorizedSession |
| 24 | + |
28 | 25 | logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(message)s") |
29 | 26 | logger = logging.getLogger(__name__) |
30 | 27 |
|
31 | 28 |
|
32 | | -class DAG: |
33 | | - """Provides necessary utils for Composer DAGs.""" |
34 | | - |
35 | | - COMPOSER_AF_VERSION_RE = re.compile( |
36 | | - "composer-(\d+)(?:\.(\d+)\.(\d+))?.*?-airflow-(\d+)\.(\d+)\.(\d+)" |
37 | | - ) |
38 | | - |
39 | | - @staticmethod |
40 | | - def get_list_of_dags( |
41 | | - project_name: str, |
42 | | - environment: str, |
43 | | - location: str, |
44 | | - sdk_endpoint: str, |
45 | | - airflow_version: tuple[int], |
46 | | - ) -> list[str]: |
47 | | - """Retrieves the list of dags for particular project.""" |
48 | | - sub_command = "list_dags" if airflow_version < (2, 0, 0) else "dags list" |
49 | | - command = ( |
50 | | - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer " |
51 | | - f"environments run {environment} --project={project_name}" |
52 | | - f" --location={location} {sub_command}" |
| 29 | +class ComposerClient: |
| 30 | + """Client for interacting with Composer API. |
| 31 | +
|
| 32 | + The client uses Google Auth and Requests under the hood. |
| 33 | + """ |
| 34 | + |
| 35 | + def __init__(self, project: str, location: str, sdk_endpoint: str) -> None: |
| 36 | + self.project = project |
| 37 | + self.location = location |
| 38 | + self.sdk_endpoint = sdk_endpoint.rstrip("/") |
| 39 | + self.credentials, _ = google.auth.default() |
| 40 | + self.session = AuthorizedSession(self.credentials) |
| 41 | + self._airflow_uris = {} |
| 42 | + |
| 43 | + def _get_airflow_uri(self, environment_name: str) -> str: |
| 44 | + """Returns the Airflow URI for a given environment, caching the result.""" |
| 45 | + if environment_name not in self._airflow_uris: |
| 46 | + environment = self.get_environment(environment_name) |
| 47 | + self._airflow_uris[environment_name] = environment["config"]["airflowUri"] |
| 48 | + return self._airflow_uris[environment_name] |
| 49 | + |
| 50 | + def get_environment(self, environment_name: str) -> Any: |
| 51 | + """Returns an environment json for a given Composer environment.""" |
| 52 | + url = ( |
| 53 | + f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/" |
| 54 | + f"{self.location}/environments/{environment_name}" |
53 | 55 | ) |
54 | | - command_output = DAG._run_shell_command_locally_once(command=command)[1] |
55 | | - if airflow_version < (2, 0, 0): |
56 | | - command_output_parsed = command_output.split() |
57 | | - return command_output_parsed[ |
58 | | - command_output_parsed.index("DAGS") + 2 : len(command_output_parsed) - 1 |
59 | | - ] |
60 | | - else: |
61 | | - # Collecting names of DAGs for output |
62 | | - list_of_dags = [] |
63 | | - for line in command_output.split("\n"): |
64 | | - if re.compile("^[a-zA-Z].*").findall(line): |
65 | | - list_of_dags.append(line.split()[0]) |
66 | | - return list_of_dags[1:] |
67 | | - |
68 | | - @staticmethod |
69 | | - def _run_shell_command_locally_once( |
70 | | - command: str, |
71 | | - command_input: str = None, |
72 | | - log_command: bool = True, |
73 | | - ) -> tuple[int, str]: |
74 | | - """Executes shell command and returns its output.""" |
75 | | - |
76 | | - p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True) |
77 | | - if log_command: |
78 | | - logger.info("Executing shell command: %s", command) |
79 | | - (res, _) = p.communicate(input=command_input) |
80 | | - if p.returncode: |
81 | | - logged_command = f' "{command}"' if log_command else "" |
82 | | - error_message = ( |
83 | | - f"Failed to run shell command{logged_command}, " f"details: {res}" |
| 56 | + response = self.session.get(url) |
| 57 | + if response.status_code != 200: |
| 58 | + raise RuntimeError( |
| 59 | + f"Failed to get environment {environment_name}: {response.text}" |
84 | 60 | ) |
85 | | - logger.error(error_message) |
86 | | - sys.exit(1) |
87 | | - return (p.returncode, str(res.decode().strip("\n"))) |
88 | | - |
89 | | - @staticmethod |
90 | | - def pause_dag( |
91 | | - project_name: str, |
92 | | - environment: str, |
93 | | - location: str, |
94 | | - sdk_endpoint: str, |
95 | | - dag_id: str, |
96 | | - airflow_version: list[int], |
97 | | - ) -> str: |
98 | | - """Pause specific DAG in the given environment.""" |
99 | | - sub_command = "pause" if airflow_version < (2, 0, 0) else "dags pause" |
100 | | - command = ( |
101 | | - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments" |
102 | | - f" run {environment} --project={project_name} --location={location}" |
103 | | - f" {sub_command} -- {dag_id}" |
104 | | - ) |
105 | | - command_output = DAG._run_shell_command_locally_once(command=command) |
106 | | - if command_output[0] == 1: |
107 | | - logger.info(command_output[1]) |
108 | | - logger.info("Error pausing DAG %s, Retrying...", dag_id) |
109 | | - command_output = DAG._run_shell_command_locally_once(command=command) |
110 | | - if command_output[0] == 1: |
111 | | - logger.info("Unable to pause DAG %s", dag_id) |
112 | | - logger.info(command_output[1]) |
113 | | - |
114 | | - @staticmethod |
115 | | - def unpause_dag( |
116 | | - project_name: str, |
117 | | - environment: str, |
118 | | - location: str, |
119 | | - sdk_endpoint: str, |
120 | | - dag_id: str, |
121 | | - airflow_version: list[int], |
122 | | - ) -> str: |
123 | | - """UnPause specific DAG in the given environment.""" |
124 | | - sub_command = "unpause" if airflow_version < (2, 0, 0) else "dags unpause" |
125 | | - command = ( |
126 | | - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments" |
127 | | - f" run {environment} --project={project_name} --location={location}" |
128 | | - f" {sub_command} -- {dag_id}" |
129 | | - ) |
130 | | - command_output = DAG._run_shell_command_locally_once(command=command) |
131 | | - if command_output[0] == 1: |
132 | | - logger.info(command_output[1]) |
133 | | - logger.info("Error Unpausing DAG %s, Retrying...", dag_id) |
134 | | - command_output = DAG._run_shell_command_locally_once(command=command) |
135 | | - if command_output[0] == 1: |
136 | | - logger.info("Unable to Unpause DAG %s", dag_id) |
137 | | - logger.info(command_output[1]) |
138 | | - |
139 | | - @staticmethod |
140 | | - def describe_environment( |
141 | | - project_name: str, environment: str, location: str, sdk_endpoint: str |
142 | | - ) -> Any: |
143 | | - """Returns the given environment json object to parse necessary details.""" |
144 | | - logger.info("*** Fetching details of the environment: %s...", environment) |
145 | | - command = ( |
146 | | - f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments" |
147 | | - f" describe {environment} --project={project_name} --location={location}" |
148 | | - f" --format json" |
149 | | - ) |
150 | | - environment_json = json.loads(DAG._run_shell_command_locally_once(command)[1]) |
151 | | - logger.info("Environment Info:\n %s", environment_json["name"]) |
152 | | - return environment_json |
| 61 | + return response.json() |
| 62 | + |
| 63 | + def pause_all_dags(self, environment_name: str) -> Any: |
| 64 | + """Pauses all DAGs in a Composer environment.""" |
| 65 | + airflow_uri = self._get_airflow_uri(environment_name) |
| 66 | + |
| 67 | + url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" |
| 68 | + response = self.session.patch(url, json={"is_paused": True}) |
| 69 | + if response.status_code != 200: |
| 70 | + raise RuntimeError(f"Failed to pause all DAGs: {response.text}") |
| 71 | + |
| 72 | + def unpause_all_dags(self, environment_name: str) -> Any: |
| 73 | + """Unpauses all DAGs in a Composer environment.""" |
| 74 | + airflow_uri = self._get_airflow_uri(environment_name) |
| 75 | + |
| 76 | + url = f"{airflow_uri}/api/v1/dags?dag_id_pattern=%" |
| 77 | + response = self.session.patch(url, json={"is_paused": False}) |
| 78 | + if response.status_code != 200: |
| 79 | + raise RuntimeError(f"Failed to unpause all DAGs: {response.text}") |
153 | 80 |
|
154 | 81 |
|
155 | 82 | def main( |
156 | | - project_name: str, environment: str, location: str, operation: str, sdk_endpoint=str |
| 83 | + project_name: str, |
| 84 | + environment: str, |
| 85 | + location: str, |
| 86 | + operation: str, |
| 87 | + sdk_endpoint: str, |
157 | 88 | ) -> int: |
158 | 89 | logger.info("DAG Pause/UnPause Script for Cloud Composer") |
159 | | - environment_info = DAG.describe_environment( |
160 | | - project_name=project_name, |
161 | | - environment=environment, |
162 | | - location=location, |
163 | | - sdk_endpoint=sdk_endpoint, |
164 | | - ) |
165 | | - versions = DAG.COMPOSER_AF_VERSION_RE.match( |
166 | | - environment_info["config"]["softwareConfig"]["imageVersion"] |
167 | | - ).groups() |
168 | | - logger.info( |
169 | | - "Image version: %s", |
170 | | - environment_info["config"]["softwareConfig"]["imageVersion"], |
171 | | - ) |
172 | | - airflow_version = (int(versions[3]), int(versions[4]), int(versions[5])) |
173 | | - list_of_dags = DAG.get_list_of_dags( |
174 | | - project_name=project_name, |
175 | | - environment=environment, |
176 | | - location=location, |
177 | | - sdk_endpoint=sdk_endpoint, |
178 | | - airflow_version=airflow_version, |
| 90 | + |
| 91 | + client = ComposerClient( |
| 92 | + project=project_name, location=location, sdk_endpoint=sdk_endpoint |
179 | 93 | ) |
180 | | - logger.info("List of dags : %s", list_of_dags) |
181 | 94 |
|
182 | 95 | if operation == "pause": |
183 | | - for dag in list_of_dags: |
184 | | - if dag == "airflow_monitoring": |
185 | | - continue |
186 | | - DAG.pause_dag( |
187 | | - project_name=project_name, |
188 | | - environment=environment, |
189 | | - location=location, |
190 | | - sdk_endpoint=sdk_endpoint, |
191 | | - dag_id=dag, |
192 | | - airflow_version=airflow_version, |
193 | | - ) |
| 96 | + logger.info("Pausing all DAGs in the environment...") |
| 97 | + client.pause_all_dags(environment) |
| 98 | + logger.info("All DAGs paused.") |
194 | 99 | else: |
195 | | - for dag in list_of_dags: |
196 | | - DAG.unpause_dag( |
197 | | - project_name=project_name, |
198 | | - environment=environment, |
199 | | - location=location, |
200 | | - sdk_endpoint=sdk_endpoint, |
201 | | - dag_id=dag, |
202 | | - airflow_version=airflow_version, |
203 | | - ) |
| 100 | + # Optimization: use bulk unpause |
| 101 | + logger.info("Unpausing all DAGs in the environment...") |
| 102 | + client.unpause_all_dags(environment) |
| 103 | + logger.info("All DAGs unpaused.") |
204 | 104 | return 0 |
205 | 105 |
|
206 | 106 |
|
|
0 commit comments