From 5f5f6c9fbf32fbc15f0dbfbeb4881817109bd4f4 Mon Sep 17 00:00:00 2001 From: chalmer lowe Date: Wed, 17 Jun 2026 13:16:32 -0400 Subject: [PATCH 1/2] fix(bigquery): close GAPIC storage transport correctly to release sockets --- .../google/cloud/bigquery/dbapi/connection.py | 2 +- .../google/cloud/bigquery/magics/magics.py | 2 +- .../google/cloud/bigquery/table.py | 2 +- .../tests/unit/test_dbapi_connection.py | 6 +++--- .../google-cloud-bigquery/tests/unit/test_table.py | 14 +++++++------- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/packages/google-cloud-bigquery/google/cloud/bigquery/dbapi/connection.py b/packages/google-cloud-bigquery/google/cloud/bigquery/dbapi/connection.py index a1a69b8fec90..b0d7ef895141 100644 --- a/packages/google-cloud-bigquery/google/cloud/bigquery/dbapi/connection.py +++ b/packages/google-cloud-bigquery/google/cloud/bigquery/dbapi/connection.py @@ -84,7 +84,7 @@ def close(self): if self._owns_bqstorage_client: # There is no close() on the BQ Storage client itself. - self._bqstorage_client._transport.grpc_channel.close() + self._bqstorage_client._transport.close() for cursor_ in self._cursors_created: if not cursor_._closed: diff --git a/packages/google-cloud-bigquery/google/cloud/bigquery/magics/magics.py b/packages/google-cloud-bigquery/google/cloud/bigquery/magics/magics.py index 1f892b595222..30bc9d27a8b6 100644 --- a/packages/google-cloud-bigquery/google/cloud/bigquery/magics/magics.py +++ b/packages/google-cloud-bigquery/google/cloud/bigquery/magics/magics.py @@ -773,4 +773,4 @@ def _close_transports(client, bqstorage_client): """ client.close() if bqstorage_client is not None: - bqstorage_client._transport.grpc_channel.close() + bqstorage_client._transport.close() diff --git a/packages/google-cloud-bigquery/google/cloud/bigquery/table.py b/packages/google-cloud-bigquery/google/cloud/bigquery/table.py index b58499343b8a..5d28d837cafe 100644 --- a/packages/google-cloud-bigquery/google/cloud/bigquery/table.py +++ b/packages/google-cloud-bigquery/google/cloud/bigquery/table.py @@ -2353,7 +2353,7 @@ def to_arrow( progress_bar.close() finally: if owns_bqstorage_client: - bqstorage_client._transport.grpc_channel.close() # type: ignore + bqstorage_client._transport.close() if record_batches and bqstorage_client is not None: return pyarrow.Table.from_batches(record_batches) diff --git a/packages/google-cloud-bigquery/tests/unit/test_dbapi_connection.py b/packages/google-cloud-bigquery/tests/unit/test_dbapi_connection.py index f5c77c448eee..8047462243dd 100644 --- a/packages/google-cloud-bigquery/tests/unit/test_dbapi_connection.py +++ b/packages/google-cloud-bigquery/tests/unit/test_dbapi_connection.py @@ -40,7 +40,7 @@ def _mock_bqstorage_client(self): from google.cloud import bigquery_storage mock_client = mock.create_autospec(bigquery_storage.BigQueryReadClient) - mock_client._transport = mock.Mock(spec=["channel"]) + mock_client._transport = mock.Mock(spec=["channel", "close"]) mock_client._transport.grpc_channel = mock.Mock(spec=["close"]) return mock_client @@ -176,7 +176,7 @@ def test_close_closes_all_created_bigquery_clients(self): connection.close() self.assertTrue(client.close.called) - self.assertTrue(bqstorage_client._transport.grpc_channel.close.called) + self.assertTrue(bqstorage_client._transport.close.called) def test_close_does_not_close_bigquery_clients_passed_to_it(self): pytest.importorskip("google.cloud.bigquery_storage") @@ -187,7 +187,7 @@ def test_close_does_not_close_bigquery_clients_passed_to_it(self): connection.close() self.assertFalse(client.close.called) - self.assertFalse(bqstorage_client._transport.grpc_channel.close.called) + self.assertFalse(bqstorage_client._transport.close.called) def test_close_closes_all_created_cursors(self): connection = self._make_one(client=self._mock_client()) diff --git a/packages/google-cloud-bigquery/tests/unit/test_table.py b/packages/google-cloud-bigquery/tests/unit/test_table.py index 0297156aef95..6a2415f6d604 100644 --- a/packages/google-cloud-bigquery/tests/unit/test_table.py +++ b/packages/google-cloud-bigquery/tests/unit/test_table.py @@ -3048,7 +3048,7 @@ def test_to_arrow_iterable_w_bqstorage(self): self.assertEqual(record_batch, expected_record_batch) # Don't close the client if it was passed in. - bqstorage_client._transport.grpc_channel.close.assert_not_called() + bqstorage_client._transport.close.assert_not_called() def test_to_arrow(self): pytest.importorskip("numpy") @@ -3424,7 +3424,7 @@ def test_to_arrow_w_bqstorage(self): self.assertEqual(actual_tbl.num_rows, total_rows) # Don't close the client if it was passed in. - bqstorage_client._transport.grpc_channel.close.assert_not_called() + bqstorage_client._transport.close.assert_not_called() def test_to_arrow_w_bqstorage_creates_client(self): pytest.importorskip("numpy") @@ -3458,7 +3458,7 @@ def test_to_arrow_w_bqstorage_creates_client(self): ) row_iterator.to_arrow(create_bqstorage_client=True) mock_client._ensure_bqstorage_client.assert_called_once() - bqstorage_client._transport.grpc_channel.close.assert_called_once() + bqstorage_client._transport.close.assert_called_once() def test_to_arrow_ensure_bqstorage_client_wo_bqstorage(self): pytest.importorskip("numpy") @@ -3741,7 +3741,7 @@ def test_to_dataframe_iterable_w_bqstorage(self): self.assertEqual(len(got), total_pages) # Don't close the client if it was passed in. - bqstorage_client._transport.grpc_channel.close.assert_not_called() + bqstorage_client._transport.close.assert_not_called() def test_to_dataframe_iterable_w_bqstorage_max_results_warning(self): pytest.importorskip("numpy") @@ -4807,7 +4807,7 @@ def test_to_dataframe_w_bqstorage_creates_client(self): ) row_iterator.to_dataframe(create_bqstorage_client=True) mock_client._ensure_bqstorage_client.assert_called_once() - bqstorage_client._transport.grpc_channel.close.assert_called_once() + bqstorage_client._transport.close.assert_called_once() def test_to_dataframe_w_bqstorage_no_streams(self): pytest.importorskip("numpy") @@ -4999,7 +4999,7 @@ def test_to_dataframe_w_bqstorage_nonempty(self): self.assertEqual(len(got.index), total_rows) # Don't close the client if it was passed in. - bqstorage_client._transport.grpc_channel.close.assert_not_called() + bqstorage_client._transport.close.assert_not_called() def test_to_dataframe_w_bqstorage_multiple_streams_return_unique_index(self): pytest.importorskip("numpy") @@ -5421,7 +5421,7 @@ def test_to_dataframe_concat_categorical_dtype_w_pyarrow(self): ) # Don't close the client if it was passed in. - bqstorage_client._transport.grpc_channel.close.assert_not_called() + bqstorage_client._transport.close.assert_not_called() def test_to_dataframe_geography_as_object(self): pandas = pytest.importorskip("pandas") From 43e6eece40f53c0e5747e87cb4c51d6e9c6b8b1b Mon Sep 17 00:00:00 2001 From: chalmer lowe Date: Wed, 17 Jun 2026 18:45:29 -0400 Subject: [PATCH 2/2] test(bigquery): refactor request interceptor to resolve socket leaks in system tests --- .../google/cloud/bigquery/__init__.py | 2 +- .../google/cloud/bigquery/table.py | 4 +- .../tests/system/helpers.py | 28 ++++++- .../tests/system/test_client.py | 75 ++++++++++-------- .../tests/system/test_magics.py | 46 ++++++----- .../tests/unit/test_magics.py | 7 +- .../tests/unit/test_table.py | 76 +++++++++++++++++++ 7 files changed, 177 insertions(+), 61 deletions(-) diff --git a/packages/google-cloud-bigquery/google/cloud/bigquery/__init__.py b/packages/google-cloud-bigquery/google/cloud/bigquery/__init__.py index d20e288f6ac3..fa03156e2e26 100644 --- a/packages/google-cloud-bigquery/google/cloud/bigquery/__init__.py +++ b/packages/google-cloud-bigquery/google/cloud/bigquery/__init__.py @@ -123,7 +123,7 @@ except ImportError: bigquery_magics = None -if sys.version_info < (3, 10): +if sys.version_info < (3, 10): # pragma: NO COVER warnings.warn( "The python-bigquery library no longer supports Python <= 3.9. " f"Your Python version is {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}. We " diff --git a/packages/google-cloud-bigquery/google/cloud/bigquery/table.py b/packages/google-cloud-bigquery/google/cloud/bigquery/table.py index 5d28d837cafe..870cdcc5d2ab 100644 --- a/packages/google-cloud-bigquery/google/cloud/bigquery/table.py +++ b/packages/google-cloud-bigquery/google/cloud/bigquery/table.py @@ -2353,7 +2353,9 @@ def to_arrow( progress_bar.close() finally: if owns_bqstorage_client: - bqstorage_client._transport.close() + # mypy: bqstorage_client is guaranteed to be not None when owns_bqstorage_client is True, + # but mypy cannot infer this correlation. We ignore the union-attr error here. + bqstorage_client._transport.close() # type: ignore[union-attr] if record_batches and bqstorage_client is not None: return pyarrow.Table.from_batches(record_batches) diff --git a/packages/google-cloud-bigquery/tests/system/helpers.py b/packages/google-cloud-bigquery/tests/system/helpers.py index 7fd344eeb071..6a8e142c2a50 100644 --- a/packages/google-cloud-bigquery/tests/system/helpers.py +++ b/packages/google-cloud-bigquery/tests/system/helpers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import datetime import decimal import uuid @@ -21,7 +22,6 @@ from google.cloud._helpers import UTC - _naive = datetime.datetime(2016, 12, 5, 12, 41, 9) _naive_microseconds = datetime.datetime(2016, 12, 5, 12, 41, 9, 250000) _stamp = "%s %s" % (_naive.date().isoformat(), _naive.time().isoformat()) @@ -104,3 +104,29 @@ def _rate_limit_exceeded(forbidden): google.api_core.exceptions.Forbidden, error_predicate=_rate_limit_exceeded, ) + + +@contextlib.contextmanager +def patch_tracked_requests(): + """Context manager to patch google-auth requests and track/close their HTTP sessions. + + This prevents socket leaks in system tests that use Workload Identity or metadata server auth. + """ + import google.auth.transport.requests + + original_init = google.auth.transport.requests.Request.__init__ + tracked_requests = [] + + def patched_init(self, session=None): + original_init(self, session=session) + if session is None: + tracked_requests.append(self) + + google.auth.transport.requests.Request.__init__ = patched_init + try: + yield tracked_requests + finally: + google.auth.transport.requests.Request.__init__ = original_init + for req in tracked_requests: + if hasattr(req, "session") and req.session is not None: + req.session.close() diff --git a/packages/google-cloud-bigquery/tests/system/test_client.py b/packages/google-cloud-bigquery/tests/system/test_client.py index b6da77c04bdb..d5ec07b5a557 100644 --- a/packages/google-cloud-bigquery/tests/system/test_client.py +++ b/packages/google-cloud-bigquery/tests/system/test_client.py @@ -58,7 +58,6 @@ from . import helpers - JOB_TIMEOUT = 120 # 2 minutes DATA_PATH = pathlib.Path(__file__).parent.parent / "data" @@ -234,23 +233,29 @@ def _create_bucket(self, bucket_name, location=None): def test_close_releases_open_sockets(self): current_process = psutil.Process() - conn_count_start = len(current_process.net_connections()) + conn_start = current_process.net_connections() + conn_count_start = len(conn_start) + + with helpers.patch_tracked_requests(): + client = Config.CLIENT + client.query( + """ + SELECT + source_year AS year, COUNT(is_male) AS birth_count + FROM `bigquery-public-data.samples.natality` + GROUP BY year + ORDER BY year DESC + LIMIT 15 + """ + ) - client = Config.CLIENT - client.query( - """ - SELECT - source_year AS year, COUNT(is_male) AS birth_count - FROM `bigquery-public-data.samples.natality` - GROUP BY year - ORDER BY year DESC - LIMIT 15 - """ - ) + client.close() - client.close() + import gc - conn_count_end = len(current_process.net_connections()) + gc.collect() + conn_end = current_process.net_connections() + conn_count_end = len(conn_end) self.assertLessEqual(conn_count_end, conn_count_start) def test_create_dataset(self): @@ -2174,25 +2179,31 @@ def test_dbapi_dry_run_query(self): def test_dbapi_connection_does_not_leak_sockets(self): pytest.importorskip("google.cloud.bigquery_storage") current_process = psutil.Process() - conn_count_start = len(current_process.net_connections()) - - # Provide no explicit clients, so that the connection will create and own them. - connection = dbapi.connect() - cursor = connection.cursor() - - cursor.execute( + conn_start = current_process.net_connections() + conn_count_start = len(conn_start) + + with helpers.patch_tracked_requests(): + # Provide no explicit clients, so that the connection will create and own them. + connection = dbapi.connect() + cursor = connection.cursor() + + cursor.execute( + """ + SELECT id, `by`, timestamp + FROM `bigquery-public-data.hacker_news.full` + ORDER BY `id` ASC + LIMIT 100000 """ - SELECT id, `by`, timestamp - FROM `bigquery-public-data.hacker_news.full` - ORDER BY `id` ASC - LIMIT 100000 - """ - ) - rows = cursor.fetchall() - self.assertEqual(len(rows), 100000) + ) + rows = cursor.fetchall() + self.assertEqual(len(rows), 100000) + + connection.close() + import gc - connection.close() - conn_count_end = len(current_process.net_connections()) + gc.collect() + conn_end = current_process.net_connections() + conn_count_end = len(conn_end) self.assertLessEqual(conn_count_end, conn_count_start) def _load_table_for_dml(self, rows, dataset_id, table_id): diff --git a/packages/google-cloud-bigquery/tests/system/test_magics.py b/packages/google-cloud-bigquery/tests/system/test_magics.py index d40b18663ef2..31fd4543eed5 100644 --- a/packages/google-cloud-bigquery/tests/system/test_magics.py +++ b/packages/google-cloud-bigquery/tests/system/test_magics.py @@ -19,6 +19,7 @@ import pytest import psutil +from . import helpers IPython = pytest.importorskip("IPython") io = pytest.importorskip("IPython.utils.io") @@ -48,27 +49,30 @@ def ipython_interactive(ipython): def test_bigquery_magic(ipython_interactive): ip = IPython.get_ipython() current_process = psutil.Process() - conn_count_start = len(current_process.net_connections()) - - # Deprecated, but should still work in google-cloud-bigquery 3.x. - with pytest.warns(FutureWarning, match="bigquery_magics"): - ip.extension_manager.load_extension("google.cloud.bigquery") - - sql = """ - SELECT - CONCAT( - 'https://stackoverflow.com/questions/', - CAST(id as STRING)) as url, - view_count - FROM `bigquery-public-data.stackoverflow.posts_questions` - WHERE tags like '%google-bigquery%' - ORDER BY view_count DESC - LIMIT 10 - """ - with io.capture_output() as captured: - result = ip.run_cell_magic("bigquery", "--use_rest_api", sql) - - conn_count_end = len(current_process.net_connections()) + conn_start = current_process.net_connections() + conn_count_start = len(conn_start) + + with helpers.patch_tracked_requests(): + # Deprecated, but should still work in google-cloud-bigquery 3.x. + with pytest.warns(FutureWarning, match="bigquery_magics"): + ip.extension_manager.load_extension("google.cloud.bigquery") + + sql = """ + SELECT + CONCAT( + 'https://stackoverflow.com/questions/', + CAST(id as STRING)) as url, + view_count + FROM `bigquery-public-data.stackoverflow.posts_questions` + WHERE tags like '%google-bigquery%' + ORDER BY view_count DESC + LIMIT 10 + """ + with io.capture_output() as captured: + result = ip.run_cell_magic("bigquery", "--use_rest_api", sql) + + conn_end = current_process.net_connections() + conn_count_end = len(conn_end) lines = re.split("\n|\r", captured.stdout) # Removes blanks & terminal code (result of display clearing) diff --git a/packages/google-cloud-bigquery/tests/unit/test_magics.py b/packages/google-cloud-bigquery/tests/unit/test_magics.py index f679d2806bc1..74b8265b3c56 100644 --- a/packages/google-cloud-bigquery/tests/unit/test_magics.py +++ b/packages/google-cloud-bigquery/tests/unit/test_magics.py @@ -45,7 +45,7 @@ @pytest.fixture() def use_local_magics_context(monkeypatch): - if magics is not None: + if magics is not None: # pragma: NO COVER local_context = magics.Context() local_context._project = "unit-test-project" mock_credentials = mock.create_autospec( @@ -2195,13 +2195,10 @@ def test_bigquery_magic_create_dataset_fails(monkeypatch): @pytest.mark.usefixtures("ipython_interactive") -def test_bigquery_magic_with_location(monkeypatch): +def test_bigquery_magic_with_location(monkeypatch, use_local_magics_context): ip = IPython.get_ipython() monkeypatch.setattr(bigquery, "bigquery_magics", None) bigquery.load_ipython_extension(ip) - magics.context.credentials = mock.create_autospec( - google.auth.credentials.Credentials, instance=True - ) run_query_patch = mock.patch( "google.cloud.bigquery.magics.magics._run_query", autospec=True diff --git a/packages/google-cloud-bigquery/tests/unit/test_table.py b/packages/google-cloud-bigquery/tests/unit/test_table.py index 6a2415f6d604..5701143a62d4 100644 --- a/packages/google-cloud-bigquery/tests/unit/test_table.py +++ b/packages/google-cloud-bigquery/tests/unit/test_table.py @@ -79,6 +79,82 @@ def test_ctor_with_key(self): self.assertEqual(encryption_config.kms_key_name, self.KMS_KEY_NAME) +class TestPropertyGraphReference(unittest.TestCase): + PROJECT = "my-project" + DATASET_ID = "my_dataset" + PROPERTY_GRAPH_ID = "my_pg" + + def _get_target_class(self): + from google.cloud.bigquery.table import PropertyGraphReference + + return PropertyGraphReference + + def _make_one(self, *args, **kw): + return self._get_target_class()(*args, **kw) + + def test_ctor(self): + dataset_ref = DatasetReference(self.PROJECT, self.DATASET_ID) + ref = self._make_one(dataset_ref, self.PROPERTY_GRAPH_ID) + self.assertEqual(ref.project, self.PROJECT) + self.assertEqual(ref.dataset_id, self.DATASET_ID) + self.assertEqual(ref.property_graph_id, self.PROPERTY_GRAPH_ID) + + def test_from_api_repr(self): + resource = { + "projectId": self.PROJECT, + "datasetId": self.DATASET_ID, + "propertyGraphId": self.PROPERTY_GRAPH_ID, + } + ref = self._get_target_class().from_api_repr(resource) + self.assertEqual(ref.project, self.PROJECT) + self.assertEqual(ref.dataset_id, self.DATASET_ID) + self.assertEqual(ref.property_graph_id, self.PROPERTY_GRAPH_ID) + + def test_to_api_repr(self): + dataset_ref = DatasetReference(self.PROJECT, self.DATASET_ID) + ref = self._make_one(dataset_ref, self.PROPERTY_GRAPH_ID) + resource = ref.to_api_repr() + expected = { + "projectId": self.PROJECT, + "datasetId": self.DATASET_ID, + "propertyGraphId": self.PROPERTY_GRAPH_ID, + } + self.assertEqual(resource, expected) + + def test___str__(self): + dataset_ref = DatasetReference(self.PROJECT, self.DATASET_ID) + ref = self._make_one(dataset_ref, self.PROPERTY_GRAPH_ID) + self.assertEqual( + str(ref), f"{self.PROJECT}.{self.DATASET_ID}.{self.PROPERTY_GRAPH_ID}" + ) + + def test___repr__(self): + dataset_ref = DatasetReference(self.PROJECT, self.DATASET_ID) + ref = self._make_one(dataset_ref, self.PROPERTY_GRAPH_ID) + expected = ( + f"PropertyGraphReference({dataset_ref!r}, '{self.PROPERTY_GRAPH_ID}')" + ) + self.assertEqual(repr(ref), expected) + + def test___eq__(self): + dataset_ref1 = DatasetReference(self.PROJECT, self.DATASET_ID) + ref1 = self._make_one(dataset_ref1, self.PROPERTY_GRAPH_ID) + dataset_ref2 = DatasetReference(self.PROJECT, self.DATASET_ID) + ref2 = self._make_one(dataset_ref2, self.PROPERTY_GRAPH_ID) + self.assertEqual(ref1, ref2) + + ref3 = self._make_one(dataset_ref1, "other_pg") + self.assertNotEqual(ref1, ref3) + self.assertNotEqual(ref1, object()) + + def test___hash__(self): + dataset_ref1 = DatasetReference(self.PROJECT, self.DATASET_ID) + ref1 = self._make_one(dataset_ref1, self.PROPERTY_GRAPH_ID) + dataset_ref2 = DatasetReference(self.PROJECT, self.DATASET_ID) + ref2 = self._make_one(dataset_ref2, self.PROPERTY_GRAPH_ID) + self.assertEqual(hash(ref1), hash(ref2)) + + class TestTableBase: @staticmethod def _get_target_class():