Skip to content

Commit 68dd064

Browse files
authored
fixit: Clean up Python sample at cloud-sql/postgres/client-side-encry… (GoogleCloudPlatform#10027)
* fixit: Clean up Python sample at cloud-sql/postgres/client-side-encryption * Removing capsys from tests
1 parent 9095542 commit 68dd064

File tree

8 files changed

+68
-64
lines changed

8 files changed

+68
-64
lines changed

cloud-sql/postgres/client-side-encryption/snippets/cloud_kms_env_aead.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
logger = logging.getLogger(__name__)
2323

2424

25-
def init_tink_env_aead(
26-
key_uri: str,
27-
credentials: str) -> tink.aead.KmsEnvelopeAead:
25+
def init_tink_env_aead(key_uri: str, credentials: str) -> tink.aead.KmsEnvelopeAead:
26+
"""
27+
Initiates the Envelope AEAD object using the KMS credentials.
28+
"""
2829
aead.register()
2930

3031
try:

cloud-sql/postgres/client-side-encryption/snippets/cloud_kms_env_aead_test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,14 @@ def setup() -> str:
2626
yield kms_uri
2727

2828

29-
def test_cloud_kms_env_aead(
30-
capsys: pytest.CaptureFixture, kms_uri: str) -> None:
29+
def test_cloud_kms_env_aead(kms_uri: str) -> None:
3130
credentials = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", None)
3231
if credentials is None:
3332
raise Exception(
34-
"Environment variable GOOGLE_APPLICATION_CREDENTIALS is not set")
33+
"Environment variable GOOGLE_APPLICATION_CREDENTIALS is not set"
34+
)
3535

3636
# Create env_aead primitive
37-
init_tink_env_aead(kms_uri, credentials)
37+
envelope = init_tink_env_aead(kms_uri, credentials)
3838

39-
captured = capsys.readouterr().out
40-
assert f"Created envelope AEAD Primitive using KMS URI: {kms_uri}" in captured
39+
assert envelope.key_template == kms_uri

cloud-sql/postgres/client-side-encryption/snippets/cloud_sql_connection_pool.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
def init_tcp_connection_engine(
2020
db_user: str, db_pass: str, db_name: str, db_host: str
2121
) -> sqlalchemy.engine.base.Engine:
22+
"""
23+
Creates a connection to the database using tcp socket.
24+
"""
2225
# Remember - storing secrets in plaintext is potentially unsafe. Consider using
2326
# something like https://cloud.google.com/secret-manager/docs/overview to help keep
2427
# secrets secret.
@@ -50,6 +53,9 @@ def init_unix_connection_engine(
5053
instance_connection_name: str,
5154
db_socket_dir: str,
5255
) -> sqlalchemy.engine.base.Engine:
56+
"""
57+
Creates a connection to the database using unix socket.
58+
"""
5359
# Remember - storing secrets in plaintext is potentially unsafe. Consider using
5460
# something like https://cloud.google.com/secret-manager/docs/overview to help keep
5561
# secrets secret.
@@ -64,9 +70,9 @@ def init_unix_connection_engine(
6470
database=db_name, # e.g. "my-database-name"
6571
query={
6672
"unix_sock": "{}/{}/.s.PGSQL.5432".format(
67-
db_socket_dir, # e.g. "/cloudsql"
68-
instance_connection_name) # i.e "<PROJECT-NAME>:<INSTANCE-REGION>:<INSTANCE-NAME>"
69-
}
73+
db_socket_dir, instance_connection_name # e.g. "/cloudsql"
74+
) # i.e "<PROJECT-NAME>:<INSTANCE-REGION>:<INSTANCE-NAME>"
75+
},
7076
),
7177
)
7278
print("Created Unix socket connection pool")
@@ -82,6 +88,7 @@ def init_db(
8288
db_socket_dir: str = None,
8389
db_host: str = None,
8490
) -> sqlalchemy.engine.base.Engine:
91+
"""Starts a connection to the database and creates voting table if it doesn't exist."""
8592

8693
if db_host:
8794
db = init_tcp_connection_engine(db_user, db_pass, db_name, db_host)

cloud-sql/postgres/client-side-encryption/snippets/cloud_sql_connection_pool_test.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
import uuid
1919

2020
import pytest
21+
import sqlalchemy
2122

2223
from snippets.cloud_sql_connection_pool import (
2324
init_db,
2425
init_tcp_connection_engine,
25-
init_unix_connection_engine
26+
init_unix_connection_engine,
2627
)
2728

2829

@@ -46,50 +47,46 @@ def setup() -> dict[str, str]:
4647
yield conn_vars
4748

4849

49-
def test_init_tcp_connection_engine(
50-
capsys: pytest.CaptureFixture,
51-
conn_vars: dict[str, str]) -> None:
52-
53-
init_tcp_connection_engine(
50+
def test_init_tcp_connection_engine(conn_vars: dict[str, str]) -> None:
51+
engine = init_tcp_connection_engine(
5452
db_user=conn_vars["db_user"],
5553
db_name=conn_vars["db_name"],
5654
db_pass=conn_vars["db_pass"],
5755
db_host=conn_vars["db_host"],
5856
)
5957

60-
captured = capsys.readouterr().out
61-
assert "Created TCP connection pool" in captured
62-
58+
assert isinstance(engine, sqlalchemy.engine.base.Engine)
59+
assert conn_vars["db_name"] in engine.url
6360

64-
def test_init_unix_connection_engine(
65-
capsys: pytest.CaptureFixture,
66-
conn_vars: dict[str, str]) -> None:
6761

68-
init_unix_connection_engine(
62+
def test_init_unix_connection_engine(conn_vars: dict[str, str]) -> None:
63+
engine = init_unix_connection_engine(
6964
db_user=conn_vars["db_user"],
7065
db_name=conn_vars["db_name"],
7166
db_pass=conn_vars["db_pass"],
7267
instance_connection_name=conn_vars["instance_conn_name"],
7368
db_socket_dir=conn_vars["db_socket_dir"],
7469
)
7570

76-
captured = capsys.readouterr().out
77-
assert "Created Unix socket connection pool" in captured
78-
71+
assert isinstance(engine, sqlalchemy.engine.base.Engine)
72+
assert conn_vars["db_name"] in engine.url
7973

80-
def test_init_db(
81-
capsys: pytest.CaptureFixture,
82-
conn_vars: dict[str, str]) -> None:
8374

75+
def test_init_db(conn_vars: dict[str, str]) -> None:
8476
table_name = f"votes_{uuid.uuid4().hex}"
8577

86-
init_db(
78+
engine = init_db(
8779
db_user=conn_vars["db_user"],
8880
db_name=conn_vars["db_name"],
8981
db_pass=conn_vars["db_pass"],
9082
table_name=table_name,
9183
db_host=conn_vars["db_host"],
9284
)
9385

94-
captured = capsys.readouterr().out
95-
assert f"Created table {table_name} in db {conn_vars['db_name']}" in captured
86+
assert isinstance(engine, sqlalchemy.engine.base.Engine)
87+
88+
try:
89+
with engine.connect() as conn:
90+
conn.execute(f"SELECT count(*) FROM {table_name}").all()
91+
except Exception as error:
92+
pytest.fail(f"Database wasn't initialized properly: {error}")

cloud-sql/postgres/client-side-encryption/snippets/encrypt_and_insert_data.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929

3030
def main() -> None:
31+
"""
32+
Connects to the database, encrypts and inserts some data.
33+
"""
3134
db_user = os.environ["DB_USER"] # e.g. "root", "postgres"
3235
db_pass = os.environ["DB_PASS"] # e.g. "mysupersecretpassword"
3336
db_name = os.environ["DB_NAME"] # e.g. "votes_db"
@@ -73,6 +76,10 @@ def encrypt_and_insert_data(
7376
team: str,
7477
email: str,
7578
) -> None:
79+
"""
80+
Inserts a vote into the database with email address previously encrypted using
81+
a KmsEnvelopeAead object.
82+
"""
7683
time_cast = datetime.datetime.now(tz=datetime.timezone.utc)
7784
# Use the envelope AEAD primitive to encrypt the email, using the team name as
7885
# associated data. Encryption with associated data ensures authenticity
@@ -93,11 +100,7 @@ def encrypt_and_insert_data(
93100
# Using a with statement ensures that the connection is always released
94101
# back into the pool at the end of statement (even if an error occurs)
95102
with db.connect() as conn:
96-
conn.execute(
97-
stmt,
98-
time_cast=time_cast,
99-
team=team,
100-
voter_email=encrypted_email)
103+
conn.execute(stmt, time_cast=time_cast, team=team, voter_email=encrypted_email)
101104
print(f"Vote successfully cast for '{team}' at time {time_cast}!")
102105

103106

cloud-sql/postgres/client-side-encryption/snippets/encrypt_and_insert_data_test.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,10 @@ def setup_key() -> tink.aead.KmsEnvelopeAead:
6565

6666

6767
def test_encrypt_and_insert_data(
68-
capsys: pytest.CaptureFixture,
6968
pool: sqlalchemy.engine.Engine,
70-
env_aead: tink.aead.KmsEnvelopeAead
69+
env_aead: tink.aead.KmsEnvelopeAead,
7170
) -> None:
72-
encrypt_and_insert_data(
73-
pool,
74-
env_aead,
75-
table_name,
76-
"SPACES",
77-
"hello@example.com")
78-
captured = capsys.readouterr()
71+
encrypt_and_insert_data(pool, env_aead, table_name, "SPACES", "hello@example.com")
7972

8073
decrypted_emails = []
8174
with pool.connect() as conn:
@@ -89,5 +82,4 @@ def test_encrypt_and_insert_data(
8982
email = env_aead.decrypt(row[2], team.encode()).decode()
9083
decrypted_emails.append(email)
9184

92-
assert "Vote successfully cast for 'SPACES'" in captured.out
9385
assert "hello@example.com" in decrypted_emails

cloud-sql/postgres/client-side-encryption/snippets/query_and_decrypt_data.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424

2525

2626
def main() -> None:
27+
"""
28+
Connects to the database, inserts encrypted data and retrieves encrypted data.
29+
"""
2730
db_user = os.environ["DB_USER"] # e.g. "root", "postgres"
2831
db_pass = os.environ["DB_PASS"] # e.g. "mysupersecretpassword"
2932
db_name = os.environ["DB_NAME"] # e.g. "votes_db"
@@ -67,7 +70,10 @@ def query_and_decrypt_data(
6770
db: sqlalchemy.engine.base.Engine,
6871
env_aead: tink.aead.KmsEnvelopeAead,
6972
table_name: str,
70-
) -> None:
73+
) -> list[tuple[str]]:
74+
"""
75+
Retrieves data from the database and decrypts it using the KmsEnvelopeAead object.
76+
"""
7177
with db.connect() as conn:
7278
# Execute the query and fetch all results
7379
recent_votes = conn.execute(
@@ -76,6 +82,7 @@ def query_and_decrypt_data(
7682
).fetchall()
7783

7884
print("Team\tEmail\tTime Cast")
85+
output = []
7986

8087
for row in recent_votes:
8188
team = row[0]
@@ -93,6 +100,8 @@ def query_and_decrypt_data(
93100

94101
# Print recent votes
95102
print(f"{team}\t{email}\t{time_cast}")
103+
output.append((team, email, time_cast))
104+
return output
96105

97106

98107
# [END cloud_sql_postgres_cse_query]

cloud-sql/postgres/client-side-encryption/snippets/query_and_decrypt_data_test.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,15 @@ def setup_key() -> tink.aead.KmsEnvelopeAead:
6868
def test_query_and_decrypt_data(
6969
capsys: pytest.CaptureFixture,
7070
pool: sqlalchemy.engine.Engine,
71-
env_aead: tink.aead.KmsEnvelopeAead
71+
env_aead: tink.aead.KmsEnvelopeAead,
7272
) -> None:
73-
7473
# Insert data into table before testing
75-
encrypt_and_insert_data(
76-
pool,
77-
env_aead,
78-
table_name,
79-
"SPACES",
80-
"hello@example.com")
81-
82-
query_and_decrypt_data(pool, env_aead, table_name)
83-
84-
captured = capsys.readouterr()
85-
assert "Team\tEmail\tTime Cast" in captured.out
86-
assert "hello@example.com" in captured.out
74+
encrypt_and_insert_data(pool, env_aead, table_name, "SPACES", "hello@example.com")
75+
76+
output = query_and_decrypt_data(pool, env_aead, table_name)
77+
78+
for row in output:
79+
if row[1] == "hello@example.com":
80+
break
81+
else:
82+
pytest.fail("Failed to find vote in the decrypted data.")

0 commit comments

Comments
 (0)