Skip to content

Commit 567789c

Browse files
author
ddeleo
committed
optimize this script to call native airflow rest api and use optimized bulk pause/unpause
1 parent e2bca55 commit 567789c

File tree

1 file changed

+67
-167
lines changed

1 file changed

+67
-167
lines changed

composer/tools/composer_dags.py

Lines changed: 67 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -15,192 +15,92 @@
1515
# limitations under the License.
1616
"""Standalone script to pause/unpause all the dags in the specific environment."""
1717

18-
from __future__ import annotations
19-
2018
import argparse
21-
import json
2219
import logging
23-
import re
24-
import subprocess
25-
import sys
2620
from typing import Any
2721

22+
import google.auth
23+
from google.auth.transport.requests import AuthorizedSession
24+
2825
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(message)s")
2926
logger = logging.getLogger(__name__)
3027

3128

32-
class DAG:
33-
"""Provides necessary utils for Composer DAGs."""
34-
35-
COMPOSER_AF_VERSION_RE = re.compile(
36-
"composer-(\d+)(?:\.(\d+)\.(\d+))?.*?-airflow-(\d+)\.(\d+)\.(\d+)"
37-
)
38-
39-
@staticmethod
40-
def get_list_of_dags(
41-
project_name: str,
42-
environment: str,
43-
location: str,
44-
sdk_endpoint: str,
45-
airflow_version: tuple[int],
46-
) -> list[str]:
47-
"""Retrieves the list of dags for particular project."""
48-
sub_command = "list_dags" if airflow_version < (2, 0, 0) else "dags list"
49-
command = (
50-
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer "
51-
f"environments run {environment} --project={project_name}"
52-
f" --location={location} {sub_command}"
29+
class ComposerClient:
30+
"""Client for interacting with Composer API.
31+
32+
The client uses Google Auth and Requests under the hood.
33+
"""
34+
35+
def __init__(self, project: str, location: str, sdk_endpoint: str) -> None:
    """Builds an authorized HTTP session for the Composer and Airflow APIs.

    Args:
        project: Project identifier, interpolated into the
            `projects/<project>` segment of Composer API URLs.
        location: Location identifier, interpolated into the
            `locations/<location>` segment of Composer API URLs.
        sdk_endpoint: Base URL of the Composer API. Trailing slashes are
            stripped so that a path can be appended safely.
    """
    self.project = project
    self.location = location
    self.sdk_endpoint = sdk_endpoint.rstrip("/")
    # Ask for the cloud-platform scope explicitly: application default
    # credentials backed by a service account carry no scopes unless they
    # are requested, and both the Composer API and the Airflow web server
    # expect an OAuth token with this scope.
    self.credentials, _ = google.auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )
    self.session = AuthorizedSession(self.credentials)
    # Cache: environment name -> Airflow web server URI.
    self._airflow_uris: dict[str, str] = {}
42+
43+
def _get_airflow_uri(self, environment_name: str) -> str:
    """Returns the Airflow URI for a given environment, caching the result.

    Args:
        environment_name: Name of the Composer environment to look up.

    Returns:
        The value of `config.airflowUri` from the environment description.

    Raises:
        KeyError: If the environment description has no
            `config.airflowUri` entry.
    """
    cached_uri = self._airflow_uris.get(environment_name)
    if cached_uri is not None:
        return cached_uri

    # Cache miss: describe the environment once and remember its URI so
    # later calls for the same environment skip the extra API round trip.
    environment = self.get_environment(environment_name)
    airflow_uri = environment["config"]["airflowUri"]
    self._airflow_uris[environment_name] = airflow_uri
    return airflow_uri
49+
50+
def get_environment(self, environment_name: str) -> Any:
51+
"""Returns an environment json for a given Composer environment."""
52+
url = (
53+
f"{self.sdk_endpoint}/v1/projects/{self.project}/locations/"
54+
f"{self.location}/environments/{environment_name}"
5355
)
54-
command_output = DAG._run_shell_command_locally_once(command=command)[1]
55-
if airflow_version < (2, 0, 0):
56-
command_output_parsed = command_output.split()
57-
return command_output_parsed[
58-
command_output_parsed.index("DAGS") + 2 : len(command_output_parsed) - 1
59-
]
60-
else:
61-
# Collecting names of DAGs for output
62-
list_of_dags = []
63-
for line in command_output.split("\n"):
64-
if re.compile("^[a-zA-Z].*").findall(line):
65-
list_of_dags.append(line.split()[0])
66-
return list_of_dags[1:]
67-
68-
@staticmethod
69-
def _run_shell_command_locally_once(
70-
command: str,
71-
command_input: str = None,
72-
log_command: bool = True,
73-
) -> tuple[int, str]:
74-
"""Executes shell command and returns its output."""
75-
76-
p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
77-
if log_command:
78-
logger.info("Executing shell command: %s", command)
79-
(res, _) = p.communicate(input=command_input)
80-
if p.returncode:
81-
logged_command = f' "{command}"' if log_command else ""
82-
error_message = (
83-
f"Failed to run shell command{logged_command}, " f"details: {res}"
56+
response = self.session.get(url)
57+
if response.status_code != 200:
58+
raise RuntimeError(
59+
f"Failed to get environment {environment_name}: {response.text}"
8460
)
85-
logger.error(error_message)
86-
sys.exit(1)
87-
return (p.returncode, str(res.decode().strip("\n")))
88-
89-
@staticmethod
90-
def pause_dag(
91-
project_name: str,
92-
environment: str,
93-
location: str,
94-
sdk_endpoint: str,
95-
dag_id: str,
96-
airflow_version: list[int],
97-
) -> str:
98-
"""Pause specific DAG in the given environment."""
99-
sub_command = "pause" if airflow_version < (2, 0, 0) else "dags pause"
100-
command = (
101-
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments"
102-
f" run {environment} --project={project_name} --location={location}"
103-
f" {sub_command} -- {dag_id}"
104-
)
105-
command_output = DAG._run_shell_command_locally_once(command=command)
106-
if command_output[0] == 1:
107-
logger.info(command_output[1])
108-
logger.info("Error pausing DAG %s, Retrying...", dag_id)
109-
command_output = DAG._run_shell_command_locally_once(command=command)
110-
if command_output[0] == 1:
111-
logger.info("Unable to pause DAG %s", dag_id)
112-
logger.info(command_output[1])
113-
114-
@staticmethod
115-
def unpause_dag(
116-
project_name: str,
117-
environment: str,
118-
location: str,
119-
sdk_endpoint: str,
120-
dag_id: str,
121-
airflow_version: list[int],
122-
) -> str:
123-
"""UnPause specific DAG in the given environment."""
124-
sub_command = "unpause" if airflow_version < (2, 0, 0) else "dags unpause"
125-
command = (
126-
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments"
127-
f" run {environment} --project={project_name} --location={location}"
128-
f" {sub_command} -- {dag_id}"
129-
)
130-
command_output = DAG._run_shell_command_locally_once(command=command)
131-
if command_output[0] == 1:
132-
logger.info(command_output[1])
133-
logger.info("Error Unpausing DAG %s, Retrying...", dag_id)
134-
command_output = DAG._run_shell_command_locally_once(command=command)
135-
if command_output[0] == 1:
136-
logger.info("Unable to Unpause DAG %s", dag_id)
137-
logger.info(command_output[1])
138-
139-
@staticmethod
140-
def describe_environment(
141-
project_name: str, environment: str, location: str, sdk_endpoint: str
142-
) -> Any:
143-
"""Returns the given environment json object to parse necessary details."""
144-
logger.info("*** Fetching details of the environment: %s...", environment)
145-
command = (
146-
f"CLOUDSDK_API_ENDPOINT_OVERRIDES_COMPOSER={sdk_endpoint} gcloud composer environments"
147-
f" describe {environment} --project={project_name} --location={location}"
148-
f" --format json"
149-
)
150-
environment_json = json.loads(DAG._run_shell_command_locally_once(command)[1])
151-
logger.info("Environment Info:\n %s", environment_json["name"])
152-
return environment_json
61+
return response.json()
62+
63+
def pause_all_dags(self, environment_name: str) -> None:
    """Pauses all DAGs in a Composer environment with one bulk request.

    Sends a single PATCH to the Airflow REST API `dags` collection with
    `is_paused` set to true.

    NOTE(review): the `%` pattern is intended to match every DAG ID, so
    the `airflow_monitoring` DAG, which the earlier per-DAG implementation
    skipped when pausing, is paused as well. Confirm that this is intended.

    Args:
        environment_name: Name of the Composer environment whose DAGs
            should be paused.

    Raises:
        RuntimeError: If the Airflow API answers with a status other
            than 200.
    """
    airflow_uri = self._get_airflow_uri(environment_name)

    # Pass the pattern through `params` so the literal "%" is sent
    # percent-encoded ("%25"); a bare "%" in a query string is not a
    # valid escape sequence.
    response = self.session.patch(
        f"{airflow_uri}/api/v1/dags",
        params={"dag_id_pattern": "%"},
        json={"is_paused": True},
    )
    if response.status_code != 200:
        raise RuntimeError(f"Failed to pause all DAGs: {response.text}")
71+
72+
def unpause_all_dags(self, environment_name: str) -> None:
    """Unpauses all DAGs in a Composer environment with one bulk request.

    Sends a single PATCH to the Airflow REST API `dags` collection with
    `is_paused` set to false.

    Args:
        environment_name: Name of the Composer environment whose DAGs
            should be unpaused.

    Raises:
        RuntimeError: If the Airflow API answers with a status other
            than 200.
    """
    airflow_uri = self._get_airflow_uri(environment_name)

    # Pass the pattern through `params` so the literal "%" is sent
    # percent-encoded ("%25"); a bare "%" in a query string is not a
    # valid escape sequence.
    response = self.session.patch(
        f"{airflow_uri}/api/v1/dags",
        params={"dag_id_pattern": "%"},
        json={"is_paused": False},
    )
    if response.status_code != 200:
        raise RuntimeError(f"Failed to unpause all DAGs: {response.text}")
15380

15481

15582
def main(
156-
project_name: str, environment: str, location: str, operation: str, sdk_endpoint=str
83+
project_name: str,
84+
environment: str,
85+
location: str,
86+
operation: str,
87+
sdk_endpoint: str,
15788
) -> int:
15889
logger.info("DAG Pause/UnPause Script for Cloud Composer")
159-
environment_info = DAG.describe_environment(
160-
project_name=project_name,
161-
environment=environment,
162-
location=location,
163-
sdk_endpoint=sdk_endpoint,
164-
)
165-
versions = DAG.COMPOSER_AF_VERSION_RE.match(
166-
environment_info["config"]["softwareConfig"]["imageVersion"]
167-
).groups()
168-
logger.info(
169-
"Image version: %s",
170-
environment_info["config"]["softwareConfig"]["imageVersion"],
171-
)
172-
airflow_version = (int(versions[3]), int(versions[4]), int(versions[5]))
173-
list_of_dags = DAG.get_list_of_dags(
174-
project_name=project_name,
175-
environment=environment,
176-
location=location,
177-
sdk_endpoint=sdk_endpoint,
178-
airflow_version=airflow_version,
90+
91+
client = ComposerClient(
92+
project=project_name, location=location, sdk_endpoint=sdk_endpoint
17993
)
180-
logger.info("List of dags : %s", list_of_dags)
18194

18295
if operation == "pause":
183-
for dag in list_of_dags:
184-
if dag == "airflow_monitoring":
185-
continue
186-
DAG.pause_dag(
187-
project_name=project_name,
188-
environment=environment,
189-
location=location,
190-
sdk_endpoint=sdk_endpoint,
191-
dag_id=dag,
192-
airflow_version=airflow_version,
193-
)
96+
logger.info("Pausing all DAGs in the environment...")
97+
client.pause_all_dags(environment)
98+
logger.info("All DAGs paused.")
19499
else:
195-
for dag in list_of_dags:
196-
DAG.unpause_dag(
197-
project_name=project_name,
198-
environment=environment,
199-
location=location,
200-
sdk_endpoint=sdk_endpoint,
201-
dag_id=dag,
202-
airflow_version=airflow_version,
203-
)
100+
# Optimization: use bulk unpause
101+
logger.info("Unpausing all DAGs in the environment...")
102+
client.unpause_all_dags(environment)
103+
logger.info("All DAGs unpaused.")
204104
return 0
205105

206106

0 commit comments

Comments
 (0)