diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_interceptor.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_interceptor.py index 3e38c4e0191d..9154c5c299b1 100644 --- a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_interceptor.py +++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_interceptor.py @@ -15,7 +15,7 @@ """Interceptor for collecting Cloud Spanner metrics.""" import re -from typing import Dict +from typing import Any, Dict from grpc_interceptor import ClientInterceptor @@ -122,10 +122,115 @@ def intercept(self, invoked_method, request_or_iterator, call_details): tracer.set_method(method_name) tracer.record_attempt_start() response = invoked_method(request_or_iterator, call_details) - tracer.record_attempt_completion() - # Process and send GFE metrics if enabled - if tracer.gfe_enabled: - metadata = response.initial_metadata() - tracer.record_gfe_metrics(metadata) + return _wrap_response(response, tracer) + + +def _wrap_response(response: Any, tracer: Any) -> Any: + """Wraps the response if it is streaming, or records metrics immediately if unary.""" + if hasattr(response, "__anext__") or hasattr(response, "__aiter__"): + return _AsyncStreamingResponseWrapper(response, tracer) + elif hasattr(response, "__next__") or hasattr(response, "__iter__"): + return _StreamingResponseWrapper(response, tracer) + else: + # Unary call: execute completion and record metrics immediately + tracer.record_attempt_completion() + metadata = [] + if hasattr(response, "initial_metadata"): + try: + metadata.extend(response.initial_metadata() or []) + except Exception: + pass + if hasattr(response, "trailing_metadata"): + try: + metadata.extend(response.trailing_metadata() or []) + except Exception: + pass + tracer.record_gfe_metrics(metadata) return response + + +class _StreamingResponseWrapper: + """Wrapper for streaming RPC response iterators to defer metrics recording.""" + + def __init__(self, response, tracer): + self._response = response + self._tracer = tracer + self._metrics_recorded = False + + def __iter__(self): + return self + + def __next__(self): + try: + return next(self._response) + except StopIteration: + self._record_metrics() + raise + except Exception: + self._record_metrics() + raise + + def _record_metrics(self): + if self._metrics_recorded: + return + self._metrics_recorded = True + self._tracer.record_attempt_completion() + metadata = [] + if hasattr(self._response, "initial_metadata"): + try: + metadata.extend(self._response.initial_metadata() or []) + except Exception: + pass + if hasattr(self._response, "trailing_metadata"): + try: + metadata.extend(self._response.trailing_metadata() or []) + except Exception: + pass + self._tracer.record_gfe_metrics(metadata) + + def __getattr__(self, name): + return getattr(self._response, name) + + +class _AsyncStreamingResponseWrapper: + """Wrapper for async streaming RPC response iterators to defer metrics recording.""" + + def __init__(self, response, tracer): + self._response = response + self._tracer = tracer + self._metrics_recorded = False + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return await self._response.__anext__() + except StopAsyncIteration: + self._record_metrics() + raise + except Exception: + self._record_metrics() + raise + + def _record_metrics(self): + if self._metrics_recorded: + return + self._metrics_recorded = True + self._tracer.record_attempt_completion() + metadata = [] + if hasattr(self._response, "initial_metadata"): + try: + metadata.extend(self._response.initial_metadata() or []) + except Exception: + pass + if hasattr(self._response, "trailing_metadata"): + try: + metadata.extend(self._response.trailing_metadata() or []) + except Exception: + pass + self._tracer.record_gfe_metrics(metadata) + + def __getattr__(self, name): + return getattr(self._response, name) diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer.py index f79869948f99..39fabedb703e 100644 --- a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer.py +++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer.py @@ -19,8 +19,9 @@ while the helper classes provide additional functionality and context for the metrics being traced. """ +import re from datetime import datetime -from typing import Dict +from typing import Any, Dict, Optional from grpc import StatusCode @@ -198,6 +199,8 @@ def __init__( instrument_operation_counter: "Counter", client_attributes: Dict[str, str], gfe_enabled: bool = False, + instrument_gfe_latency: Optional["Histogram"] = None, + instrument_gfe_missing_header_count: Optional["Counter"] = None, ): """ Initialize a MetricsTracer instance with the given parameters. @@ -214,6 +217,8 @@ def __init__( instrument_operation_counter (Counter): Instrument for counting operations. client_attributes (Dict[str, str]): Dictionary of client attributes used for metrics tracing. gfe_enabled (bool, optional): Indicates if GFE metrics are enabled. Defaults to False. + instrument_gfe_latency (Histogram, optional): Instrument for measuring GFE latency. + instrument_gfe_missing_header_count (Counter, optional): Instrument for counting missing GFE headers. """ self.current_op = MetricOpTracer() self._client_attributes = client_attributes @@ -221,8 +226,10 @@ def __init__( self._instrument_attempt_counter = instrument_attempt_counter self._instrument_operation_latency = instrument_operation_latency self._instrument_operation_counter = instrument_operation_counter + self._instrument_gfe_latency = instrument_gfe_latency + self._instrument_gfe_missing_header_count = instrument_gfe_missing_header_count self.enabled = enabled - self.gfe_enabled = gfe_enabled + self.gfe_enabled = True @staticmethod def _get_ms_time_diff(start: datetime, end: datetime) -> float: @@ -399,7 +406,11 @@ def record_gfe_latency(self, latency: int) -> None: Args: latency (int): The latency duration to be recorded. """ - if not self.enabled or not HAS_OPENTELEMETRY_INSTALLED or not self.gfe_enabled: + if ( + not self.enabled + or not HAS_OPENTELEMETRY_INSTALLED + or not getattr(self, "_instrument_gfe_latency", None) + ): return self._instrument_gfe_latency.record( amount=latency, attributes=self.client_attributes @@ -409,12 +420,65 @@ def record_gfe_missing_header_count(self) -> None: """ Increments the counter for missing GFE headers. """ - if not self.enabled or not HAS_OPENTELEMETRY_INSTALLED or not self.gfe_enabled: + if ( + not self.enabled + or not HAS_OPENTELEMETRY_INSTALLED + or not getattr(self, "_instrument_gfe_missing_header_count", None) + ): return self._instrument_gfe_missing_header_count.add( amount=1, attributes=self.client_attributes ) + @staticmethod + def extract_gfe_latency(metadata: Any) -> Optional[int]: + """ + Extracts the GFE latency value (in milliseconds) from response metadata. + """ + if not metadata: + return None + + header_vals = [] + if isinstance(metadata, dict): + for key, val in metadata.items(): + if key and str(key).lower() in ("server-timing", "server_timing"): + if isinstance(val, (list, tuple)): + header_vals.extend(val) + else: + header_vals.append(val) + elif isinstance(metadata, (list, tuple)): + for key, val in metadata: + if key and str(key).lower() in ("server-timing", "server_timing"): + if isinstance(val, (list, tuple)): + header_vals.extend(val) + else: + header_vals.append(val) + + for header_val in header_vals: + if not header_val: + continue + if not isinstance(header_val, str): + header_val = str(header_val) + match = re.search(r"gfet4t7;\s*dur=([0-9]+)", header_val) + if match: + try: + return int(match.group(1)) + except ValueError: + pass + return None + + def record_gfe_metrics(self, metadata: Any) -> None: + """ + Extracts and records GFE metrics from the RPC response metadata. + """ + if not self.enabled or not HAS_OPENTELEMETRY_INSTALLED: + return + latency = self.extract_gfe_latency(metadata) + if latency is not None: + self.record_gfe_latency(latency) + else: + self.record_gfe_missing_header_count() + def _create_operation_otel_attributes(self) -> dict: """ Create additional attributes for operation metrics tracing. diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer_factory.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer_factory.py index f22d285c9750..029dddfaa15a 100644 --- a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer_factory.py +++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/metrics_tracer_factory.py @@ -85,6 +85,7 @@ def __init__(self, enabled: bool, service_name: str): project (str): The project ID for the monitored resource. """ self.enabled = enabled + self.gfe_enabled = True self._create_metric_instruments(service_name) self._client_attributes = {} @@ -268,6 +269,11 @@ def create_metrics_tracer(self) -> MetricsTracer: instrument_operation_latency=self._instrument_operation_latency, instrument_operation_counter=self._instrument_operation_counter, client_attributes=self._client_attributes.copy(), + gfe_enabled=True, + instrument_gfe_latency=getattr(self, "_instrument_gfe_latency", None), + instrument_gfe_missing_header_count=getattr( + self, "_instrument_gfe_missing_header_count", None + ), ) return metrics_tracer diff --git a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py index 6fc5956582c1..7886e555f120 100644 --- a/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py +++ b/packages/google-cloud-spanner/google/cloud/spanner_v1/metrics/spanner_metrics_tracer_factory.py @@ -51,9 +51,7 @@ class SpannerMetricsTracerFactory(MetricsTracerFactory): "current_metrics_tracer", default=None ) - def __new__( - cls, enabled: bool = True, gfe_enabled: bool = False - ) -> "SpannerMetricsTracerFactory": + def __new__(cls, enabled: bool = True) -> "SpannerMetricsTracerFactory": """ Create a new instance of SpannerMetricsTracerFactory if it doesn't already exist. @@ -63,7 +61,6 @@ def __new__( Args: enabled (bool): A flag indicating whether metrics tracing is enabled. Defaults to True. - gfe_enabled (bool): A flag indicating whether GFE metrics are enabled. Defaults to False. Returns: SpannerMetricsTracerFactory: The singleton instance of SpannerMetricsTracerFactory. @@ -83,7 +80,7 @@ def __new__( cls._generate_client_hash(client_uid) ) cls._metrics_tracer_factory.set_location(_get_cloud_region()) - cls._metrics_tracer_factory.gfe_enabled = gfe_enabled + cls._metrics_tracer_factory.gfe_enabled = True if cls._metrics_tracer_factory.enabled != enabled: cls._metrics_tracer_factory.enabled = enabled diff --git a/packages/google-cloud-spanner/tests/mockserver_tests/test_gfe_metrics.py b/packages/google-cloud-spanner/tests/mockserver_tests/test_gfe_metrics.py new file mode 100644 index 000000000000..6d4cf51bff4e --- /dev/null +++ b/packages/google-cloud-spanner/tests/mockserver_tests/test_gfe_metrics.py @@ -0,0 +1,206 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import mock + +import grpc +from google.api_core.client_options import ClientOptions +from google.auth.credentials import AnonymousCredentials +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import InMemoryMetricReader + +import google.cloud.spanner_v1.client as client_mod +from google.cloud.spanner_v1 import Client +from google.cloud.spanner_v1.metrics.metrics_interceptor import MetricsInterceptor +from google.cloud.spanner_v1.metrics.spanner_metrics_tracer_factory import ( + SpannerMetricsTracerFactory, +) +from google.cloud.spanner_v1.pool import FixedSizePool +from tests.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_select1_result, +) + + +class TestGFEMetricsIntegration(MockServerTestBase): + def setUp(self): + super().setUp() + os.environ["SPANNER_DISABLE_BUILTIN_METRICS"] = "false" + SpannerMetricsTracerFactory._metrics_tracer_factory = None + client_mod._metrics_monitor_initialized = False + + def tearDown(self): + super().tearDown() + os.environ["SPANNER_DISABLE_BUILTIN_METRICS"] = "true" + SpannerMetricsTracerFactory._metrics_tracer_factory = None + client_mod._metrics_monitor_initialized = False + + def test_gfe_metrics_exported(self): + add_select1_result() + reader = InMemoryMetricReader() + meter_provider = MeterProvider(metric_readers=[reader]) + + orig_call = grpc._channel._UnaryStreamMultiCallable.__call__ + orig_initial_metadata = grpc._channel._MultiThreadedRendezvous.initial_metadata + orig_trailing_metadata = ( + grpc._channel._MultiThreadedRendezvous.trailing_metadata + ) + + def custom_initial_metadata(self): + mocked = getattr(self, "_is_execute_streaming_sql_mock", False) + if mocked: + return (("server-timing", "gfet4t7; dur=55"),) + return orig_initial_metadata(self) + + def custom_trailing_metadata(self): + mocked = getattr(self, "_is_execute_streaming_sql_mock", False) + if mocked: + return (("server-timing", "gfet4t7; dur=55"),) + return orig_trailing_metadata(self) + + def custom_call(self_callable, request, *args, **kwargs): + method = getattr(self_callable, "_method", b"") + method_str = method.decode("utf-8") if isinstance(method, bytes) else method + response = orig_call(self_callable, request, *args, **kwargs) + if "ExecuteStreamingSql" in method_str: + response._is_execute_streaming_sql_mock = True + return response + + try: + with ( + mock.patch( + "google.cloud.spanner_v1.metrics.metrics_tracer_factory.get_meter_provider", + return_value=meter_provider, + ), + mock.patch( + "google.cloud.spanner_v1.client.MeterProvider", + return_value=meter_provider, + ), + mock.patch( + "google.cloud.spanner_v1.client._get_spanner_emulator_host", + return_value=None, + ), + mock.patch( + "grpc._channel._UnaryStreamMultiCallable.__call__", + custom_call, + ), + mock.patch( + "grpc._channel._MultiThreadedRendezvous.initial_metadata", + custom_initial_metadata, + ), + mock.patch( + "grpc._channel._MultiThreadedRendezvous.trailing_metadata", + custom_trailing_metadata, + ), + ): + client = Client( + project="p", + credentials=AnonymousCredentials(), + client_options=ClientOptions( + api_endpoint="localhost:" + str(MockServerTestBase.port), + ), + ) + instance = client.instance("test-instance") + database = instance.database( + "test-database", + pool=FixedSizePool(size=10), + enable_interceptors_in_tests=True, + ) + database._interceptors.append(MetricsInterceptor()) + database._spanner_api = ( + None # Force recreation with the new interceptor + ) + + with database.snapshot() as snapshot: + results = snapshot.execute_sql("select 1") + # Consume the streaming results to complete the stream + list(results) + + metric_data = reader.get_metrics_data() + self.assertIsNotNone(metric_data) + metrics = { + metric.name: metric + for rm in metric_data.resource_metrics + for sm in rm.scope_metrics + for metric in sm.metrics + } + + self.assertIn("gfe_latency", metrics, f"Metrics: {list(metrics.keys())}") + gfe_metric = metrics["gfe_latency"] + point = next(iter(gfe_metric.data.data_points)) + self.assertEqual(point.sum, 55) + + finally: + pass + + def test_gfe_missing_header_count_exported(self): + add_select1_result() + reader = InMemoryMetricReader() + meter_provider = MeterProvider(metric_readers=[reader]) + + try: + with ( + mock.patch( + "google.cloud.spanner_v1.metrics.metrics_tracer_factory.get_meter_provider", + return_value=meter_provider, + ), + mock.patch( + "google.cloud.spanner_v1.client.MeterProvider", + return_value=meter_provider, + ), + mock.patch( + "google.cloud.spanner_v1.client._get_spanner_emulator_host", + return_value=None, + ), + ): + client = Client( + project="p", + credentials=AnonymousCredentials(), + client_options=ClientOptions( + api_endpoint="localhost:" + str(MockServerTestBase.port), + ), + ) + instance = client.instance("test-instance") + database = instance.database( + "test-database", + pool=FixedSizePool(size=10), + enable_interceptors_in_tests=True, + ) + database._interceptors.append(MetricsInterceptor()) + database._spanner_api = ( + None # Force recreation with the new interceptor + ) + + with database.snapshot() as snapshot: + results = snapshot.execute_sql("select 1") + list(results) + + metric_data = reader.get_metrics_data() + self.assertIsNotNone(metric_data) + metrics = { + metric.name: metric + for rm in metric_data.resource_metrics + for sm in rm.scope_metrics + for metric in sm.metrics + } + + self.assertIn( + "gfe_missing_header_count", metrics, f"Metrics: {list(metrics.keys())}" + ) + missing_metric = metrics["gfe_missing_header_count"] + point = next(iter(missing_metric.data.data_points)) + self.assertGreaterEqual(point.value, 1) + finally: + pass diff --git a/packages/google-cloud-spanner/tests/unit/test_metrics_interceptor.py b/packages/google-cloud-spanner/tests/unit/test_metrics_interceptor.py index 6e091860b425..efa080191c9e 100644 --- a/packages/google-cloud-spanner/tests/unit/test_metrics_interceptor.py +++ b/packages/google-cloud-spanner/tests/unit/test_metrics_interceptor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest @@ -41,7 +41,7 @@ def __init__(self): self.project = None self.instance = None self.database = None - self.gfe_enabled = False + self.gfe_enabled = True self.record_attempt_start = MagicMock() self.record_attempt_completion = MagicMock() self.set_method = MagicMock() @@ -99,10 +99,8 @@ def test_set_metrics_tracer_attributes(interceptor, mock_tracer_ctx): def test_intercept_with_tracer(interceptor, mock_tracer_ctx): # mock_tracer_ctx fixture sets the ContextVar - mock_tracer_ctx.gfe_enabled = False - - invoked_response = MagicMock() - invoked_response.initial_metadata.return_value = {} + invoked_response = Mock() + invoked_response.initial_metadata.return_value = [] mock_invoked_method = MagicMock(return_value=invoked_response) call_details = MagicMock( @@ -119,4 +117,5 @@ def test_intercept_with_tracer(interceptor, mock_tracer_ctx): assert response == invoked_response mock_tracer_ctx.record_attempt_start.assert_called() mock_tracer_ctx.record_attempt_completion.assert_called_once() + mock_tracer_ctx.record_gfe_metrics.assert_called_once() mock_invoked_method.assert_called_once_with("request", call_details) diff --git a/packages/google-cloud-spanner/tests/unit/test_metrics_tracer.py b/packages/google-cloud-spanner/tests/unit/test_metrics_tracer.py index 90b2f2f511f9..4769974f0c8a 100644 --- a/packages/google-cloud-spanner/tests/unit/test_metrics_tracer.py +++ b/packages/google-cloud-spanner/tests/unit/test_metrics_tracer.py @@ -264,3 +264,36 @@ def test_record_gfe_missing_header_count(metrics_tracer): metrics_tracer.record_gfe_missing_header_count() assert mock_gfe_missing_header_count.add.call_count == 1 # Should not increment metrics_tracer.enabled = True # Reset for next test + + +def test_extract_gfe_latency(): + # Valid trailing metadata list of tuples + metadata_list = [("server-timing", "gfet4t7; dur=123")] + assert MetricsTracer.extract_gfe_latency(metadata_list) == 123 + + # Valid metadata dict + metadata_dict = {"server-timing": "gfet4t7; dur=456"} + assert MetricsTracer.extract_gfe_latency(metadata_dict) == 456 + + # Missing header + assert MetricsTracer.extract_gfe_latency([("other-header", "val")]) is None + assert MetricsTracer.extract_gfe_latency(None) is None + + +def test_record_gfe_metrics(metrics_tracer): + mock_gfe_latency = mock.create_autospec(Histogram, instance=True) + mock_gfe_missing = mock.create_autospec(Counter, instance=True) + metrics_tracer._instrument_gfe_latency = mock_gfe_latency + metrics_tracer._instrument_gfe_missing_header_count = mock_gfe_missing + metrics_tracer.gfe_enabled = True + + # With header + metrics_tracer.record_gfe_metrics([("server-timing", "gfet4t7; dur=88")]) + assert mock_gfe_latency.record.call_count == 1 + assert mock_gfe_latency.record.call_args[1]["amount"] == 88 + assert mock_gfe_missing.add.call_count == 0 + + # Without header + metrics_tracer.record_gfe_metrics([("other", "1")]) + assert mock_gfe_latency.record.call_count == 1 + assert mock_gfe_missing.add.call_count == 1