|
| 1 | +# Copyright 2018 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
1 | 15 | from contextlib import contextmanager |
2 | 16 | import logging |
3 | 17 | import os |
4 | 18 | from typing import Dict |
5 | 19 | import uuid |
6 | 20 |
|
7 | | -import pymysql |
8 | 21 | import pytest |
| 22 | +import sqlalchemy |
9 | 23 |
|
10 | 24 | import main |
11 | 25 |
|
@@ -57,24 +71,31 @@ def unix_db_connection(): |
57 | 71 |
|
58 | 72 |
|
59 | 73 | def _common_setup(): |
| 74 | + pool = main.init_connection_engine() |
| 75 | + |
| 76 | + table_name: str = uuid.uuid4().hex |
| 77 | + |
60 | 78 | try: |
61 | | - pool = main.init_connection_engine() |
62 | | - except pymysql.err.OperationalError as e: |
| 79 | + with pool.connect() as conn: |
| 80 | + conn.execute( |
| 81 | + f"CREATE TABLE IF NOT EXISTS `{table_name}`" |
| 82 | + "( vote_id SERIAL NOT NULL, time_cast timestamp NOT NULL, " |
| 83 | + "candidate CHAR(6) NOT NULL, PRIMARY KEY (vote_id) );" |
| 84 | + ) |
| 85 | + except sqlalchemy.exc.OperationalError as e: |
63 | 86 | logger.warning( |
64 | 87 | "Could not connect to the production database. " |
65 | 88 | "If running tests locally, is the cloud_sql_proxy currently running?" |
66 | 89 | ) |
| 90 | + # If there is cloud sql proxy log, dump the contents. |
| 91 | + home_dir = os.environ.get("HOME", "") |
| 92 | + log_file = f"{home_dir}/cloud_sql_proxy.log" |
| 93 | + if home_dir and os.path.isfile(log_file): |
| 94 | + print(f"Dumping the contents of {log_file}") |
| 95 | + with open(log_file, "r") as f: |
| 96 | + print(f.read()) |
67 | 97 | raise e |
68 | 98 |
|
69 | | - table_name: str = uuid.uuid4().hex |
70 | | - |
71 | | - with pool.connect() as conn: |
72 | | - conn.execute( |
73 | | - f"CREATE TABLE IF NOT EXISTS `{table_name}`" |
74 | | - "( vote_id SERIAL NOT NULL, time_cast timestamp NOT NULL, " |
75 | | - "candidate CHAR(6) NOT NULL, PRIMARY KEY (vote_id) );" |
76 | | - ) |
77 | | - |
78 | 99 | yield pool |
79 | 100 |
|
80 | 101 | with pool.connect() as conn: |
|
0 commit comments