Skip to content

Commit 77adf99

Browse files
committed
test(bigquery): add connection debug logs to diagnose socket leaks
1 parent 5f5f6c9 commit 77adf99

6 files changed

Lines changed: 208 additions & 60 deletions

File tree

packages/google-cloud-bigquery/google/cloud/bigquery/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@
123123
except ImportError:
124124
bigquery_magics = None
125125

126-
if sys.version_info < (3, 10):
126+
if sys.version_info < (3, 10): # pragma: NO COVER
127127
warnings.warn(
128128
"The python-bigquery library no longer supports Python <= 3.9. "
129129
f"Your Python version is {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}. We "

packages/google-cloud-bigquery/google/cloud/bigquery/table.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2353,7 +2353,9 @@ def to_arrow(
23532353
progress_bar.close()
23542354
finally:
23552355
if owns_bqstorage_client:
2356-
bqstorage_client._transport.close()
2356+
# mypy: bqstorage_client is guaranteed to be not None when owns_bqstorage_client is True,
2357+
# but mypy cannot infer this correlation. We ignore the union-attr error here.
2358+
bqstorage_client._transport.close() # type: ignore[union-attr]
23572359

23582360
if record_batches and bqstorage_client is not None:
23592361
return pyarrow.Table.from_batches(record_batches)

packages/google-cloud-bigquery/tests/system/test_client.py

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
import base64
16+
import contextlib
17+
import google.auth.transport.requests
1618
import copy
1719
import csv
1820
import datetime
@@ -155,6 +157,25 @@ def _load_json_schema(filename="schema.json"):
155157
return _parse_schema_resource(json.load(schema_file))
156158

157159

160+
@contextlib.contextmanager
def patch_tracked_requests():
    """Record every ``google.auth`` ``Request`` created inside the ``with`` block.

    While the context is active,
    ``google.auth.transport.requests.Request.__init__`` is replaced by a
    wrapper that calls the original initializer and then appends the new
    instance to a list. On exit the original initializer is restored and the
    ``session`` of every recorded request is closed, when it has one.

    Yields:
        list: The recorded ``Request`` instances, in creation order. The list
        keeps growing for as long as the patch is active.
    """
    # Resolve the class once so the restore in ``finally`` targets exactly the
    # object that was patched.
    request_cls = google.auth.transport.requests.Request
    original_init = request_cls.__init__
    tracked_requests = []

    def patched_init(self, *args, **kwargs):
        # Forward the arguments untouched: the wrapper must not constrain how
        # ``Request`` is constructed (positional, keyword, or extra arguments).
        original_init(self, *args, **kwargs)
        # Only record instances whose initializer completed without raising.
        tracked_requests.append(self)

    request_cls.__init__ = patched_init
    try:
        yield tracked_requests
    finally:
        request_cls.__init__ = original_init
        for req in tracked_requests:
            session = getattr(req, "session", None)
            if session is not None:
                session.close()
177+
178+
158179
class Config(object):
159180
"""Run-time configuration to be modified at set-up.
160181
@@ -234,23 +255,34 @@ def _create_bucket(self, bucket_name, location=None):
234255

235256
def test_close_releases_open_sockets(self):
236257
current_process = psutil.Process()
237-
conn_count_start = len(current_process.net_connections())
258+
conn_start = current_process.net_connections()
259+
conn_count_start = len(conn_start)
260+
261+
with patch_tracked_requests() as tracked_requests:
262+
client = Config.CLIENT
263+
client.query(
264+
"""
265+
SELECT
266+
source_year AS year, COUNT(is_male) AS birth_count
267+
FROM `bigquery-public-data.samples.natality`
268+
GROUP BY year
269+
ORDER BY year DESC
270+
LIMIT 15
271+
"""
272+
)
238273

239-
client = Config.CLIENT
240-
client.query(
241-
"""
242-
SELECT
243-
source_year AS year, COUNT(is_male) AS birth_count
244-
FROM `bigquery-public-data.samples.natality`
245-
GROUP BY year
246-
ORDER BY year DESC
247-
LIMIT 15
248-
"""
249-
)
274+
client.close()
250275

251-
client.close()
276+
import gc
252277

253-
conn_count_end = len(current_process.net_connections())
278+
gc.collect()
279+
conn_end = current_process.net_connections()
280+
conn_count_end = len(conn_end)
281+
if conn_count_end > conn_count_start:
282+
print("DEBUG: test_close_releases_open_sockets failed!")
283+
print(f"DEBUG: Start connections ({conn_count_start}): {conn_start}")
284+
print(f"DEBUG: End connections ({conn_count_end}): {conn_end}")
285+
print(f"DEBUG: Tracked requests: {tracked_requests}")
254286
self.assertLessEqual(conn_count_end, conn_count_start)
255287

256288
def test_create_dataset(self):
@@ -2174,25 +2206,36 @@ def test_dbapi_dry_run_query(self):
21742206
def test_dbapi_connection_does_not_leak_sockets(self):
21752207
pytest.importorskip("google.cloud.bigquery_storage")
21762208
current_process = psutil.Process()
2177-
conn_count_start = len(current_process.net_connections())
2178-
2179-
# Provide no explicit clients, so that the connection will create and own them.
2180-
connection = dbapi.connect()
2181-
cursor = connection.cursor()
2182-
2183-
cursor.execute(
2209+
conn_start = current_process.net_connections()
2210+
conn_count_start = len(conn_start)
2211+
2212+
with patch_tracked_requests() as tracked_requests:
2213+
# Provide no explicit clients, so that the connection will create and own them.
2214+
connection = dbapi.connect()
2215+
cursor = connection.cursor()
2216+
2217+
cursor.execute(
2218+
"""
2219+
SELECT id, `by`, timestamp
2220+
FROM `bigquery-public-data.hacker_news.full`
2221+
ORDER BY `id` ASC
2222+
LIMIT 100000
21842223
"""
2185-
SELECT id, `by`, timestamp
2186-
FROM `bigquery-public-data.hacker_news.full`
2187-
ORDER BY `id` ASC
2188-
LIMIT 100000
2189-
"""
2190-
)
2191-
rows = cursor.fetchall()
2192-
self.assertEqual(len(rows), 100000)
2193-
2194-
connection.close()
2195-
conn_count_end = len(current_process.net_connections())
2224+
)
2225+
rows = cursor.fetchall()
2226+
self.assertEqual(len(rows), 100000)
2227+
2228+
connection.close()
2229+
import gc
2230+
2231+
gc.collect()
2232+
conn_end = current_process.net_connections()
2233+
conn_count_end = len(conn_end)
2234+
if conn_count_end > conn_count_start:
2235+
print("DEBUG: test_dbapi_connection_does_not_leak_sockets failed!")
2236+
print(f"DEBUG: Start connections ({conn_count_start}): {conn_start}")
2237+
print(f"DEBUG: End connections ({conn_count_end}): {conn_end}")
2238+
print(f"DEBUG: Tracked requests: {tracked_requests}")
21962239
self.assertLessEqual(conn_count_end, conn_count_start)
21972240

21982241
def _load_table_for_dml(self, rows, dataset_id, table_id):

packages/google-cloud-bigquery/tests/system/test_magics.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414

1515
"""System tests for Jupyter/IPython connector."""
1616

17+
import contextlib
1718
import re
1819

20+
import google.auth.transport.requests
21+
1922
import pytest
2023
import psutil
2124

@@ -45,30 +48,52 @@ def ipython_interactive(ipython):
4548
yield ipython
4649

4750

51+
@contextlib.contextmanager
52+
def patch_tracked_requests():
53+
original_init = google.auth.transport.requests.Request.__init__
54+
tracked_requests = []
55+
56+
def patched_init(self, session=None):
57+
original_init(self, session=session)
58+
tracked_requests.append(self)
59+
60+
google.auth.transport.requests.Request.__init__ = patched_init
61+
try:
62+
yield tracked_requests
63+
finally:
64+
google.auth.transport.requests.Request.__init__ = original_init
65+
for req in tracked_requests:
66+
if hasattr(req, "session") and req.session is not None:
67+
req.session.close()
68+
69+
4870
def test_bigquery_magic(ipython_interactive):
4971
ip = IPython.get_ipython()
5072
current_process = psutil.Process()
51-
conn_count_start = len(current_process.net_connections())
52-
53-
# Deprecated, but should still work in google-cloud-bigquery 3.x.
54-
with pytest.warns(FutureWarning, match="bigquery_magics"):
55-
ip.extension_manager.load_extension("google.cloud.bigquery")
56-
57-
sql = """
58-
SELECT
59-
CONCAT(
60-
'https://stackoverflow.com/questions/',
61-
CAST(id as STRING)) as url,
62-
view_count
63-
FROM `bigquery-public-data.stackoverflow.posts_questions`
64-
WHERE tags like '%google-bigquery%'
65-
ORDER BY view_count DESC
66-
LIMIT 10
67-
"""
68-
with io.capture_output() as captured:
69-
result = ip.run_cell_magic("bigquery", "--use_rest_api", sql)
70-
71-
conn_count_end = len(current_process.net_connections())
73+
conn_start = current_process.net_connections()
74+
conn_count_start = len(conn_start)
75+
76+
with patch_tracked_requests() as tracked_requests:
77+
# Deprecated, but should still work in google-cloud-bigquery 3.x.
78+
with pytest.warns(FutureWarning, match="bigquery_magics"):
79+
ip.extension_manager.load_extension("google.cloud.bigquery")
80+
81+
sql = """
82+
SELECT
83+
CONCAT(
84+
'https://stackoverflow.com/questions/',
85+
CAST(id as STRING)) as url,
86+
view_count
87+
FROM `bigquery-public-data.stackoverflow.posts_questions`
88+
WHERE tags like '%google-bigquery%'
89+
ORDER BY view_count DESC
90+
LIMIT 10
91+
"""
92+
with io.capture_output() as captured:
93+
result = ip.run_cell_magic("bigquery", "--use_rest_api", sql)
94+
95+
conn_end = current_process.net_connections()
96+
conn_count_end = len(conn_end)
7297

7398
lines = re.split("\n|\r", captured.stdout)
7499
# Removes blanks & terminal code (result of display clearing)
@@ -82,4 +107,9 @@ def test_bigquery_magic(ipython_interactive):
82107
# NOTE: For some reason, the number of open sockets is sometimes one *less*
83108
# than expected when running system tests on Kokoro, thus using the <= assertion.
84109
# That's still fine, however, since the sockets are apparently not leaked.
110+
if conn_count_end > conn_count_start:
111+
print("DEBUG: test_bigquery_magic failed!")
112+
print(f"DEBUG: Start connections ({conn_count_start}): {conn_start}")
113+
print(f"DEBUG: End connections ({conn_count_end}): {conn_end}")
114+
print(f"DEBUG: Tracked requests: {tracked_requests}")
85115
assert conn_count_end <= conn_count_start # system resources are released

packages/google-cloud-bigquery/tests/unit/test_magics.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646
@pytest.fixture()
4747
def use_local_magics_context(monkeypatch):
48-
if magics is not None:
48+
if magics is not None: # pragma: NO COVER
4949
local_context = magics.Context()
5050
local_context._project = "unit-test-project"
5151
mock_credentials = mock.create_autospec(
@@ -2195,13 +2195,10 @@ def test_bigquery_magic_create_dataset_fails(monkeypatch):
21952195

21962196

21972197
@pytest.mark.usefixtures("ipython_interactive")
2198-
def test_bigquery_magic_with_location(monkeypatch):
2198+
def test_bigquery_magic_with_location(monkeypatch, use_local_magics_context):
21992199
ip = IPython.get_ipython()
22002200
monkeypatch.setattr(bigquery, "bigquery_magics", None)
22012201
bigquery.load_ipython_extension(ip)
2202-
magics.context.credentials = mock.create_autospec(
2203-
google.auth.credentials.Credentials, instance=True
2204-
)
22052202

22062203
run_query_patch = mock.patch(
22072204
"google.cloud.bigquery.magics.magics._run_query", autospec=True

packages/google-cloud-bigquery/tests/unit/test_table.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,82 @@ def test_ctor_with_key(self):
7979
self.assertEqual(encryption_config.kms_key_name, self.KMS_KEY_NAME)
8080

8181

82+
class TestPropertyGraphReference(unittest.TestCase):
83+
PROJECT = "my-project"
84+
DATASET_ID = "my_dataset"
85+
PROPERTY_GRAPH_ID = "my_pg"
86+
87+
def _get_target_class(self):
88+
from google.cloud.bigquery.table import PropertyGraphReference
89+
90+
return PropertyGraphReference
91+
92+
def _make_one(self, *args, **kw):
93+
return self._get_target_class()(*args, **kw)
94+
95+
def test_ctor(self):
96+
dataset_ref = DatasetReference(self.PROJECT, self.DATASET_ID)
97+
ref = self._make_one(dataset_ref, self.PROPERTY_GRAPH_ID)
98+
self.assertEqual(ref.project, self.PROJECT)
99+
self.assertEqual(ref.dataset_id, self.DATASET_ID)
100+
self.assertEqual(ref.property_graph_id, self.PROPERTY_GRAPH_ID)
101+
102+
def test_from_api_repr(self):
103+
resource = {
104+
"projectId": self.PROJECT,
105+
"datasetId": self.DATASET_ID,
106+
"propertyGraphId": self.PROPERTY_GRAPH_ID,
107+
}
108+
ref = self._get_target_class().from_api_repr(resource)
109+
self.assertEqual(ref.project, self.PROJECT)
110+
self.assertEqual(ref.dataset_id, self.DATASET_ID)
111+
self.assertEqual(ref.property_graph_id, self.PROPERTY_GRAPH_ID)
112+
113+
def test_to_api_repr(self):
114+
dataset_ref = DatasetReference(self.PROJECT, self.DATASET_ID)
115+
ref = self._make_one(dataset_ref, self.PROPERTY_GRAPH_ID)
116+
resource = ref.to_api_repr()
117+
expected = {
118+
"projectId": self.PROJECT,
119+
"datasetId": self.DATASET_ID,
120+
"propertyGraphId": self.PROPERTY_GRAPH_ID,
121+
}
122+
self.assertEqual(resource, expected)
123+
124+
def test___str__(self):
125+
dataset_ref = DatasetReference(self.PROJECT, self.DATASET_ID)
126+
ref = self._make_one(dataset_ref, self.PROPERTY_GRAPH_ID)
127+
self.assertEqual(
128+
str(ref), f"{self.PROJECT}.{self.DATASET_ID}.{self.PROPERTY_GRAPH_ID}"
129+
)
130+
131+
def test___repr__(self):
132+
dataset_ref = DatasetReference(self.PROJECT, self.DATASET_ID)
133+
ref = self._make_one(dataset_ref, self.PROPERTY_GRAPH_ID)
134+
expected = (
135+
f"PropertyGraphReference({dataset_ref!r}, '{self.PROPERTY_GRAPH_ID}')"
136+
)
137+
self.assertEqual(repr(ref), expected)
138+
139+
def test___eq__(self):
140+
dataset_ref1 = DatasetReference(self.PROJECT, self.DATASET_ID)
141+
ref1 = self._make_one(dataset_ref1, self.PROPERTY_GRAPH_ID)
142+
dataset_ref2 = DatasetReference(self.PROJECT, self.DATASET_ID)
143+
ref2 = self._make_one(dataset_ref2, self.PROPERTY_GRAPH_ID)
144+
self.assertEqual(ref1, ref2)
145+
146+
ref3 = self._make_one(dataset_ref1, "other_pg")
147+
self.assertNotEqual(ref1, ref3)
148+
self.assertNotEqual(ref1, object())
149+
150+
def test___hash__(self):
151+
dataset_ref1 = DatasetReference(self.PROJECT, self.DATASET_ID)
152+
ref1 = self._make_one(dataset_ref1, self.PROPERTY_GRAPH_ID)
153+
dataset_ref2 = DatasetReference(self.PROJECT, self.DATASET_ID)
154+
ref2 = self._make_one(dataset_ref2, self.PROPERTY_GRAPH_ID)
155+
self.assertEqual(hash(ref1), hash(ref2))
156+
157+
82158
class TestTableBase:
83159
@staticmethod
84160
def _get_target_class():

0 commit comments

Comments
 (0)