diff --git a/CHANGELOG.md b/CHANGELOG.md
index 70faaf89..90b45ce6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,14 @@
 
 [1]: https://pypi.org/project/google-cloud-bigquery-storage/#history
 
+## [2.6.0](https://www.github.com/googleapis/python-bigquery-storage/compare/v2.5.0...v2.6.0) (2021-07-09)
+
+
+### Features
+
+* `read_session` optional to `ReadRowsStream.rows()` ([#228](https://www.github.com/googleapis/python-bigquery-storage/issues/228)) ([4f56029](https://www.github.com/googleapis/python-bigquery-storage/commit/4f5602950a0c1959e332aa2964245b9caf4828c8))
+* add always_use_jwt_access ([#223](https://www.github.com/googleapis/python-bigquery-storage/issues/223)) ([fd82417](https://www.github.com/googleapis/python-bigquery-storage/commit/fd824174fb044fbacc83c647f619fda556333e26))
+
 ## [2.5.0](https://www.github.com/googleapis/python-bigquery-storage/compare/v2.4.0...v2.5.0) (2021-06-29)
 
diff --git a/google/cloud/bigquery_storage_v1/reader.py b/google/cloud/bigquery_storage_v1/reader.py
index 034ad726..a8cd226c 100644
--- a/google/cloud/bigquery_storage_v1/reader.py
+++ b/google/cloud/bigquery_storage_v1/reader.py
@@ -156,7 +156,7 @@ def _reconnect(self):
             read_stream=self._name, offset=self._offset, **self._read_rows_kwargs
         )
 
-    def rows(self, read_session):
+    def rows(self, read_session=None):
         """Iterate over all rows in the stream.
 
         This method requires the fastavro library in order to parse row
@@ -169,19 +169,21 @@ def rows(self, read_session):
 
         Args:
             read_session ( \
-                ~google.cloud.bigquery_storage_v1.types.ReadSession \
+                Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \
             ):
-                The read session associated with this read rows stream. This
-                contains the schema, which is required to parse the data
-                messages.
+                DEPRECATED.
+
+                This argument was used to specify the schema of the rows in the
+                stream, but now the first message in a read stream contains
+                this information.
 
         Returns:
             Iterable[Mapping]:
                 A sequence of rows, represented as dictionaries.
         """
-        return ReadRowsIterable(self, read_session)
+        return ReadRowsIterable(self, read_session=read_session)
 
-    def to_arrow(self, read_session):
+    def to_arrow(self, read_session=None):
         """Create a :class:`pyarrow.Table` of all rows in the stream.
 
         This method requires the pyarrow library and a stream using the Arrow
@@ -191,17 +193,19 @@ def to_arrow(self, read_session):
             read_session ( \
                 ~google.cloud.bigquery_storage_v1.types.ReadSession \
             ):
-                The read session associated with this read rows stream. This
-                contains the schema, which is required to parse the data
-                messages.
+                DEPRECATED.
+
+                This argument was used to specify the schema of the rows in the
+                stream, but now the first message in a read stream contains
+                this information.
 
         Returns:
             pyarrow.Table:
                 A table of all rows in the stream.
         """
-        return self.rows(read_session).to_arrow()
+        return self.rows(read_session=read_session).to_arrow()
 
-    def to_dataframe(self, read_session, dtypes=None):
+    def to_dataframe(self, read_session=None, dtypes=None):
         """Create a :class:`pandas.DataFrame` of all rows in the stream.
 
         This method requires the pandas libary to create a data frame and the
@@ -215,9 +219,11 @@ def to_dataframe(self, read_session, dtypes=None):
             read_session ( \
                 ~google.cloud.bigquery_storage_v1.types.ReadSession \
             ):
-                The read session associated with this read rows stream. This
-                contains the schema, which is required to parse the data
-                messages.
+                DEPRECATED.
+
+                This argument was used to specify the schema of the rows in the
+                stream, but now the first message in a read stream contains
+                this information.
             dtypes ( \
                 Map[str, Union[str, pandas.Series.dtype]] \
             ):
@@ -233,7 +239,7 @@
         if pandas is None:
             raise ImportError(_PANDAS_REQUIRED)
 
-        return self.rows(read_session).to_dataframe(dtypes=dtypes)
+        return self.rows(read_session=read_session).to_dataframe(dtypes=dtypes)
 
 
 class ReadRowsIterable(object):
@@ -242,18 +248,25 @@ class ReadRowsIterable(object):
 
     Args:
         reader (google.cloud.bigquery_storage_v1.reader.ReadRowsStream):
            A read rows stream.
-        read_session (google.cloud.bigquery_storage_v1.types.ReadSession):
-            A read session. This is required because it contains the schema
-            used in the stream messages.
+        read_session ( \
+            Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \
+        ):
+            DEPRECATED.
+
+            This argument was used to specify the schema of the rows in the
+            stream, but now the first message in a read stream contains
+            this information.
     """
 
     # This class is modelled after the google.cloud.bigquery.table.RowIterator
    # and aims to be API compatible where possible.
 
-    def __init__(self, reader, read_session):
+    def __init__(self, reader, read_session=None):
         self._reader = reader
-        self._read_session = read_session
-        self._stream_parser = _StreamParser.from_read_session(self._read_session)
+        if read_session is not None:
+            self._stream_parser = _StreamParser.from_read_session(read_session)
+        else:
+            self._stream_parser = None
 
     @property
     def pages(self):
@@ -266,6 +279,10 @@ def pages(self):
         # Each page is an iterator of rows. But also has num_items, remaining,
         # and to_dataframe.
         for message in self._reader:
+            # Only the first message contains the schema, which is needed to
+            # decode the messages.
+            if not self._stream_parser:
+                self._stream_parser = _StreamParser.from_read_rows_response(message)
             yield ReadRowsPage(self._stream_parser, message)
 
     def __iter__(self):
@@ -328,10 +345,11 @@ def to_dataframe(self, dtypes=None):
         # pandas dataframe is about 2x faster. This is because pandas.concat is
         # rarely no-copy, whereas pyarrow.Table.from_batches + to_pandas is
         # usually no-copy.
-        schema_type = self._read_session._pb.WhichOneof("schema")
-
-        if schema_type == "arrow_schema":
+        try:
             record_batch = self.to_arrow()
+        except NotImplementedError:
+            pass
+        else:
             df = record_batch.to_pandas()
             for column in dtypes:
                 df[column] = pandas.Series(df[column], dtype=dtypes[column])
@@ -491,6 +509,12 @@ def to_dataframe(self, message, dtypes=None):
     def to_rows(self, message):
         raise NotImplementedError("Not implemented.")
 
+    def _parse_avro_schema(self):
+        raise NotImplementedError("Not implemented.")
+
+    def _parse_arrow_schema(self):
+        raise NotImplementedError("Not implemented.")
+
     @staticmethod
     def from_read_session(read_session):
         schema_type = read_session._pb.WhichOneof("schema")
@@ -503,22 +527,38 @@ def from_read_session(read_session):
             "Unsupported schema type in read_session: {0}".format(schema_type)
         )
 
+    @staticmethod
+    def from_read_rows_response(message):
+        schema_type = message._pb.WhichOneof("schema")
+        if schema_type == "avro_schema":
+            return _AvroStreamParser(message)
+        elif schema_type == "arrow_schema":
+            return _ArrowStreamParser(message)
+        else:
+            raise TypeError(
+                "Unsupported schema type in message: {0}".format(schema_type)
+            )
+
 
 class _AvroStreamParser(_StreamParser):
     """Helper to parse Avro messages into useful representations."""
 
-    def __init__(self, read_session):
+    def __init__(self, message):
         """Construct an _AvroStreamParser.
 
         Args:
-            read_session (google.cloud.bigquery_storage_v1.types.ReadSession):
-                A read session. This is required because it contains the schema
-                used in the stream messages.
+            message (Union[
+                google.cloud.bigquery_storage_v1.types.ReadSession, \
+                google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \
+            ]):
+                Either the first message of data from a read rows stream or a
+                read session. Both types contain a oneof "schema" field, which
+                can be used to determine how to deserialize rows.
         """
         if fastavro is None:
             raise ImportError(_FASTAVRO_REQUIRED)
 
-        self._read_session = read_session
+        self._first_message = message
         self._avro_schema_json = None
         self._fastavro_schema = None
         self._column_names = None
@@ -548,6 +588,10 @@ def to_dataframe(self, message, dtypes=None):
         strings in the fastavro library.
 
         Args:
+            message ( \
+                ~google.cloud.bigquery_storage_v1.types.ReadRowsResponse \
+            ):
+                A message containing Avro bytes to parse into a pandas DataFrame.
             dtypes ( \
                 Map[str, Union[str, pandas.Series.dtype]] \
             ):
@@ -578,10 +622,11 @@ def _parse_avro_schema(self):
         if self._avro_schema_json:
             return
 
-        self._avro_schema_json = json.loads(self._read_session.avro_schema.schema)
+        self._avro_schema_json = json.loads(self._first_message.avro_schema.schema)
         self._column_names = tuple(
             (field["name"] for field in self._avro_schema_json["fields"])
         )
+        self._first_message = None
 
     def _parse_fastavro(self):
         """Convert parsed Avro schema to fastavro format."""
@@ -615,11 +660,22 @@ def to_rows(self, message):
 
 
 class _ArrowStreamParser(_StreamParser):
-    def __init__(self, read_session):
+    def __init__(self, message):
+        """Construct an _ArrowStreamParser.
+
+        Args:
+            message (Union[
+                google.cloud.bigquery_storage_v1.types.ReadSession, \
+                google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \
+            ]):
+                Either the first message of data from a read rows stream or a
+                read session. Both types contain a oneof "schema" field, which
+                can be used to determine how to deserialize rows.
+        """
         if pyarrow is None:
             raise ImportError(_PYARROW_REQUIRED)
 
-        self._read_session = read_session
+        self._first_message = message
         self._schema = None
 
     def to_arrow(self, message):
@@ -659,6 +715,7 @@ def _parse_arrow_schema(self):
             return
 
         self._schema = pyarrow.ipc.read_schema(
-            pyarrow.py_buffer(self._read_session.arrow_schema.serialized_schema)
+            pyarrow.py_buffer(self._first_message.arrow_schema.serialized_schema)
         )
         self._column_names = [field.name for field in self._schema]
+        self._first_message = None
diff --git a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/base.py b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/base.py
index e43dd679..af7f5390 100644
--- a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/base.py
+++ b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/base.py
@@ -24,6 +24,7 @@ from google.api_core import gapic_v1  # type: ignore
 from google.api_core import retry as retries  # type: ignore
 from google.auth import credentials as ga_credentials  # type: ignore
+from google.oauth2 import service_account  # type: ignore
 
 from google.cloud.bigquery_storage_v1.types import storage
 from google.cloud.bigquery_storage_v1.types import stream
@@ -46,8 +47,6 @@
 except pkg_resources.DistributionNotFound:  # pragma: NO COVER
     _GOOGLE_AUTH_VERSION = None
 
-_API_CORE_VERSION = google.api_core.__version__
-
 
 class BigQueryReadTransport(abc.ABC):
     """Abstract transport class for BigQueryRead."""
@@ -69,6 +68,7 @@ def __init__(
         scopes: Optional[Sequence[str]] = None,
         quota_project_id: Optional[str] = None,
         client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        always_use_jwt_access: Optional[bool] = False,
         **kwargs,
     ) -> None:
         """Instantiate the transport.
@@ -92,6 +92,8 @@ def __init__(
                 API requests. If ``None``, then default info will be used.
                 Generally, you only need to set this if you're developing
                 your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
         """
         # Save the hostname. Default to port 443 (HTTPS) if none is specified.
         if ":" not in host:
@@ -101,7 +103,7 @@ def __init__(
         scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)
 
         # Save the scopes.
-        self._scopes = scopes or self.AUTH_SCOPES
+        self._scopes = scopes
 
         # If no credentials are provided, then determine the appropriate
         # defaults.
@@ -120,13 +122,20 @@ def __init__(
                 **scopes_kwargs, quota_project_id=quota_project_id
             )
 
+        # If the credentials are service account credentials, then always try to use self signed JWT.
+        if (
+            always_use_jwt_access
+            and isinstance(credentials, service_account.Credentials)
+            and hasattr(service_account.Credentials, "with_always_use_jwt_access")
+        ):
+            credentials = credentials.with_always_use_jwt_access(True)
+
         # Save the credentials.
         self._credentials = credentials
 
-        # TODO(busunkim): These two class methods are in the base transport
+        # TODO(busunkim): This method is in the base transport
         # to avoid duplicating code across the transport classes. These functions
-        # should be deleted once the minimum required versions of google-api-core
-        # and google-auth are increased.
+        # should be deleted once the minimum required version of google-auth is increased.
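For reference, the guard above makes the new flag a no-op for non-service-account credentials and on google-auth releases that predate `with_always_use_jwt_access` (hence the `hasattr` check). A minimal usage sketch — an editorial illustration, not part of the patch; "key.json" is a placeholder key-file path:

# Hedged sketch: opting in to self-signed JWTs for service account
# credentials. The flag is silently ignored unless the installed
# google-auth exposes with_always_use_jwt_access.
from google.oauth2 import service_account

from google.cloud.bigquery_storage_v1.services.big_query_read.transports.grpc import (
    BigQueryReadGrpcTransport,
)

creds = service_account.Credentials.from_service_account_file("key.json")
transport = BigQueryReadGrpcTransport(
    credentials=creds,
    always_use_jwt_access=True,  # swap credentials for self-signed-JWT variant
)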
 
        # TODO: Remove this function once google-auth >= 1.25.0 is required
        @classmethod
@@ -147,27 +156,6 @@ def _get_scopes_kwargs(
 
         return scopes_kwargs
 
-    # TODO: Remove this function once google-api-core >= 1.26.0 is required
-    @classmethod
-    def _get_self_signed_jwt_kwargs(
-        cls, host: str, scopes: Optional[Sequence[str]]
-    ) -> Dict[str, Union[Optional[Sequence[str]], str]]:
-        """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version"""
-
-        self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {}
-
-        if _API_CORE_VERSION and (
-            packaging.version.parse(_API_CORE_VERSION)
-            >= packaging.version.parse("1.26.0")
-        ):
-            self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
-            self_signed_jwt_kwargs["scopes"] = scopes
-            self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
-        else:
-            self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES
-
-        return self_signed_jwt_kwargs
-
     def _prep_wrapped_messages(self, client_info):
         # Precompute the wrapped methods.
         self._wrapped_methods = {
diff --git a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc.py b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc.py
index 28905942..6cb890e9 100644
--- a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc.py
+++ b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc.py
@@ -59,6 +59,7 @@ def __init__(
         client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
         quota_project_id: Optional[str] = None,
         client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        always_use_jwt_access: Optional[bool] = False,
     ) -> None:
         """Instantiate the transport.
 
@@ -99,6 +100,8 @@ def __init__(
                 API requests. If ``None``, then default info will be used.
                 Generally, you only need to set this if you're developing
                 your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
 
         Raises:
             google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
@@ -151,6 +154,7 @@ def __init__(
             scopes=scopes,
             quota_project_id=quota_project_id,
             client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
         )
 
         if not self._grpc_channel:
@@ -206,14 +210,14 @@ def create_channel(
             and ``credentials_file`` are passed.
         """
 
-        self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)
-
         return grpc_helpers.create_channel(
             host,
             credentials=credentials,
             credentials_file=credentials_file,
             quota_project_id=quota_project_id,
-            **self_signed_jwt_kwargs,
+            default_scopes=cls.AUTH_SCOPES,
+            scopes=scopes,
+            default_host=cls.DEFAULT_HOST,
             **kwargs,
         )
 
diff --git a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc_asyncio.py b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc_asyncio.py
index 24e8fa96..fd5ecad0 100644
--- a/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc_asyncio.py
+++ b/google/cloud/bigquery_storage_v1/services/big_query_read/transports/grpc_asyncio.py
@@ -80,14 +80,14 @@ def create_channel(
             aio.Channel: A gRPC AsyncIO channel object.
""" - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -105,6 +105,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -146,6 +147,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -197,6 +200,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/base.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/base.py index efe4bead..fadedc4d 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/base.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/base.py @@ -24,6 +24,7 @@ from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.bigquery_storage_v1beta2.types import storage from google.cloud.bigquery_storage_v1beta2.types import stream @@ -46,8 +47,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class BigQueryReadTransport(abc.ABC): """Abstract transport class for BigQueryRead.""" @@ -69,6 +68,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -92,6 +92,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -101,7 +103,7 @@ def __init__( scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) # Save the scopes. - self._scopes = scopes or self.AUTH_SCOPES + self._scopes = scopes # If no credentials are provided, then determine the appropriate # defaults. @@ -120,13 +122,20 @@ def __init__( **scopes_kwargs, quota_project_id=quota_project_id ) + # If the credentials is service account credentials, then always try to use self signed JWT. 
+        if (
+            always_use_jwt_access
+            and isinstance(credentials, service_account.Credentials)
+            and hasattr(service_account.Credentials, "with_always_use_jwt_access")
+        ):
+            credentials = credentials.with_always_use_jwt_access(True)
+
         # Save the credentials.
         self._credentials = credentials
 
-        # TODO(busunkim): These two class methods are in the base transport
+        # TODO(busunkim): This method is in the base transport
         # to avoid duplicating code across the transport classes. These functions
-        # should be deleted once the minimum required versions of google-api-core
-        # and google-auth are increased.
+        # should be deleted once the minimum required version of google-auth is increased.
 
        # TODO: Remove this function once google-auth >= 1.25.0 is required
        @classmethod
@@ -147,27 +156,6 @@ def _get_scopes_kwargs(
 
         return scopes_kwargs
 
-    # TODO: Remove this function once google-api-core >= 1.26.0 is required
-    @classmethod
-    def _get_self_signed_jwt_kwargs(
-        cls, host: str, scopes: Optional[Sequence[str]]
-    ) -> Dict[str, Union[Optional[Sequence[str]], str]]:
-        """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version"""
-
-        self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {}
-
-        if _API_CORE_VERSION and (
-            packaging.version.parse(_API_CORE_VERSION)
-            >= packaging.version.parse("1.26.0")
-        ):
-            self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
-            self_signed_jwt_kwargs["scopes"] = scopes
-            self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
-        else:
-            self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES
-
-        return self_signed_jwt_kwargs
-
     def _prep_wrapped_messages(self, client_info):
         # Precompute the wrapped methods.
         self._wrapped_methods = {
diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc.py
index bb7d3252..54f5fed1 100644
--- a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc.py
+++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc.py
@@ -61,6 +61,7 @@ def __init__(
         client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
         quota_project_id: Optional[str] = None,
         client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        always_use_jwt_access: Optional[bool] = False,
     ) -> None:
         """Instantiate the transport.
 
@@ -101,6 +102,8 @@ def __init__(
                 API requests. If ``None``, then default info will be used.
                 Generally, you only need to set this if you're developing
                 your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
 
         Raises:
             google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
@@ -153,6 +156,7 @@ def __init__(
             scopes=scopes,
             quota_project_id=quota_project_id,
             client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
         )
 
         if not self._grpc_channel:
@@ -208,14 +212,14 @@ def create_channel(
             and ``credentials_file`` are passed.
""" - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc_asyncio.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc_asyncio.py index aa017209..7cb0784f 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc_asyncio.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_read/transports/grpc_asyncio.py @@ -82,14 +82,14 @@ def create_channel( aio.Channel: A gRPC AsyncIO channel object. """ - self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes) - return grpc_helpers_async.create_channel( host, credentials=credentials, credentials_file=credentials_file, quota_project_id=quota_project_id, - **self_signed_jwt_kwargs, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, **kwargs, ) @@ -107,6 +107,7 @@ def __init__( client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, ) -> None: """Instantiate the transport. @@ -148,6 +149,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -199,6 +202,7 @@ def __init__( scopes=scopes, quota_project_id=quota_project_id, client_info=client_info, + always_use_jwt_access=always_use_jwt_access, ) if not self._grpc_channel: diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/base.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/base.py index f0006332..7286ce5b 100644 --- a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/base.py +++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/base.py @@ -24,6 +24,7 @@ from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.bigquery_storage_v1beta2.types import storage from google.cloud.bigquery_storage_v1beta2.types import stream @@ -46,8 +47,6 @@ except pkg_resources.DistributionNotFound: # pragma: NO COVER _GOOGLE_AUTH_VERSION = None -_API_CORE_VERSION = google.api_core.__version__ - class BigQueryWriteTransport(abc.ABC): """Abstract transport class for BigQueryWrite.""" @@ -69,6 +68,7 @@ def __init__( scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, **kwargs, ) -> None: """Instantiate the transport. @@ -92,6 +92,8 @@ def __init__( API requests. If ``None``, then default info will be used. Generally, you only need to set this if you're developing your own client library. 
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
         """
         # Save the hostname. Default to port 443 (HTTPS) if none is specified.
         if ":" not in host:
@@ -101,7 +103,7 @@ def __init__(
         scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)
 
         # Save the scopes.
-        self._scopes = scopes or self.AUTH_SCOPES
+        self._scopes = scopes
 
         # If no credentials are provided, then determine the appropriate
         # defaults.
@@ -120,13 +122,20 @@ def __init__(
                 **scopes_kwargs, quota_project_id=quota_project_id
             )
 
+        # If the credentials are service account credentials, then always try to use self signed JWT.
+        if (
+            always_use_jwt_access
+            and isinstance(credentials, service_account.Credentials)
+            and hasattr(service_account.Credentials, "with_always_use_jwt_access")
+        ):
+            credentials = credentials.with_always_use_jwt_access(True)
+
         # Save the credentials.
         self._credentials = credentials
 
-        # TODO(busunkim): These two class methods are in the base transport
+        # TODO(busunkim): This method is in the base transport
         # to avoid duplicating code across the transport classes. These functions
-        # should be deleted once the minimum required versions of google-api-core
-        # and google-auth are increased.
+        # should be deleted once the minimum required version of google-auth is increased.
 
        # TODO: Remove this function once google-auth >= 1.25.0 is required
        @classmethod
@@ -147,27 +156,6 @@ def _get_scopes_kwargs(
 
         return scopes_kwargs
 
-    # TODO: Remove this function once google-api-core >= 1.26.0 is required
-    @classmethod
-    def _get_self_signed_jwt_kwargs(
-        cls, host: str, scopes: Optional[Sequence[str]]
-    ) -> Dict[str, Union[Optional[Sequence[str]], str]]:
-        """Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version"""
-
-        self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {}
-
-        if _API_CORE_VERSION and (
-            packaging.version.parse(_API_CORE_VERSION)
-            >= packaging.version.parse("1.26.0")
-        ):
-            self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
-            self_signed_jwt_kwargs["scopes"] = scopes
-            self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
-        else:
-            self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES
-
-        return self_signed_jwt_kwargs
-
     def _prep_wrapped_messages(self, client_info):
         # Precompute the wrapped methods.
         self._wrapped_methods = {
diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc.py
index 50d4c8d6..b9013f22 100644
--- a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc.py
+++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc.py
@@ -59,6 +59,7 @@ def __init__(
         client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
         quota_project_id: Optional[str] = None,
         client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        always_use_jwt_access: Optional[bool] = False,
     ) -> None:
         """Instantiate the transport.
 
@@ -99,6 +100,8 @@ def __init__(
                 API requests. If ``None``, then default info will be used.
                 Generally, you only need to set this if you're developing
                 your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
 
         Raises:
             google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
@@ -151,6 +154,7 @@ def __init__(
             scopes=scopes,
             quota_project_id=quota_project_id,
             client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
         )
 
         if not self._grpc_channel:
@@ -206,14 +210,14 @@ def create_channel(
             and ``credentials_file`` are passed.
         """
 
-        self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)
-
         return grpc_helpers.create_channel(
             host,
             credentials=credentials,
             credentials_file=credentials_file,
             quota_project_id=quota_project_id,
-            **self_signed_jwt_kwargs,
+            default_scopes=cls.AUTH_SCOPES,
+            scopes=scopes,
+            default_host=cls.DEFAULT_HOST,
             **kwargs,
         )
 
diff --git a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc_asyncio.py b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc_asyncio.py
index 23f15d49..41597592 100644
--- a/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc_asyncio.py
+++ b/google/cloud/bigquery_storage_v1beta2/services/big_query_write/transports/grpc_asyncio.py
@@ -80,14 +80,14 @@ def create_channel(
             aio.Channel: A gRPC AsyncIO channel object.
         """
 
-        self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)
-
         return grpc_helpers_async.create_channel(
             host,
             credentials=credentials,
             credentials_file=credentials_file,
             quota_project_id=quota_project_id,
-            **self_signed_jwt_kwargs,
+            default_scopes=cls.AUTH_SCOPES,
+            scopes=scopes,
+            default_host=cls.DEFAULT_HOST,
             **kwargs,
         )
 
@@ -105,6 +105,7 @@ def __init__(
         client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
         quota_project_id=None,
         client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        always_use_jwt_access: Optional[bool] = False,
     ) -> None:
         """Instantiate the transport.
 
@@ -146,6 +147,8 @@ def __init__(
                 API requests. If ``None``, then default info will be used.
                 Generally, you only need to set this if you're developing
                 your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
 
         Raises:
             google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
@@ -197,6 +200,7 @@ def __init__(
             scopes=scopes,
             quota_project_id=quota_project_id,
             client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
         )
 
         if not self._grpc_channel:
diff --git a/samples/quickstart/requirements.txt b/samples/quickstart/requirements.txt
index f3466a91..9eddd046 100644
--- a/samples/quickstart/requirements.txt
+++ b/samples/quickstart/requirements.txt
@@ -1,2 +1,2 @@
 fastavro
-google-cloud-bigquery-storage==2.4.0
+google-cloud-bigquery-storage==2.5.0
diff --git a/samples/to_dataframe/requirements.txt b/samples/to_dataframe/requirements.txt
index 455e6894..45aeaa2f 100644
--- a/samples/to_dataframe/requirements.txt
+++ b/samples/to_dataframe/requirements.txt
@@ -1,5 +1,5 @@
-google-auth==1.32.0
-google-cloud-bigquery-storage==2.4.0
+google-auth==1.32.1
+google-cloud-bigquery-storage==2.5.0
 google-cloud-bigquery==2.20.0
 pyarrow==4.0.1
 ipython==7.10.2; python_version > '3.0'
diff --git a/setup.py b/setup.py
index 6f77b670..59c70ba2 100644
--- a/setup.py
+++ b/setup.py
@@ -21,10 +21,10 @@
 
 name = "google-cloud-bigquery-storage"
 description = "BigQuery Storage API API client library"
-version = "2.5.0"
+version = "2.6.0"
 release_status = "Development Status :: 5 - Production/Stable"
 dependencies = [
-    "google-api-core[grpc] >= 1.22.2, < 2.0.0dev",
+    "google-api-core[grpc] >= 1.26.0, < 2.0.0dev",
     "proto-plus >= 1.4.0",
     "packaging >= 14.3",
     "libcst >= 0.2.5",
diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt
index 52ae0efd..f9186709 100644
--- a/testing/constraints-3.6.txt
+++ b/testing/constraints-3.6.txt
@@ -5,11 +5,11 @@
 #
 # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
 # Then this file should have foo==1.14.0
-google-api-core==1.22.2
+google-api-core==1.26.0
 proto-plus==1.4.0
 libcst==0.2.5
 fastavro==0.21.2
 pandas==0.21.1
 pyarrow==0.15.0
 packaging==14.3
-google-auth==1.24.0  # TODO: remove when google-auth>=1.25.0 si transitively required through google-api-core
+google-auth==1.24.0  # TODO: remove when google-auth>=1.25.0 is transitively required through google-api-core
diff --git a/tests/system/conftest.py b/tests/system/conftest.py
index a18777dd..3a89097a 100644
--- a/tests/system/conftest.py
+++ b/tests/system/conftest.py
@@ -18,13 +18,41 @@
 import os
 import uuid
 
+import google.auth
+from google.cloud import bigquery
 import pytest
+import test_utils.prefixer
 
 from . import helpers
 
+prefixer = test_utils.prefixer.Prefixer("python-bigquery-storage", "tests/system")
+
+
 _TABLE_FORMAT = "projects/{}/datasets/{}/tables/{}"
 _ASSETS_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "assets")
 
+_ALL_TYPES_SCHEMA = [
+    bigquery.SchemaField("string_field", "STRING"),
+    bigquery.SchemaField("bytes_field", "BYTES"),
+    bigquery.SchemaField("int64_field", "INT64"),
+    bigquery.SchemaField("float64_field", "FLOAT64"),
+    bigquery.SchemaField("numeric_field", "NUMERIC"),
+    bigquery.SchemaField("bool_field", "BOOL"),
+    bigquery.SchemaField("geography_field", "GEOGRAPHY"),
+    bigquery.SchemaField(
+        "person_struct_field",
+        "STRUCT",
+        fields=(
+            bigquery.SchemaField("name", "STRING"),
+            bigquery.SchemaField("age", "INT64"),
+        ),
+    ),
+    bigquery.SchemaField("timestamp_field", "TIMESTAMP"),
+    bigquery.SchemaField("date_field", "DATE"),
+    bigquery.SchemaField("time_field", "TIME"),
+    bigquery.SchemaField("datetime_field", "DATETIME"),
+    bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"),
+]
 
 
 @pytest.fixture(scope="session")
@@ -38,18 +66,9 @@ def use_mtls():
 
 
 @pytest.fixture(scope="session")
-def credentials(use_mtls):
-    import google.auth
-    from google.oauth2 import service_account
-
-    if use_mtls:
-        # mTLS test uses user credentials instead of service account credentials
-        creds, _ = google.auth.default()
-        return creds
-
-    # NOTE: the test config in noxfile checks that the env variable is indeed set
-    filename = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
-    return service_account.Credentials.from_service_account_file(filename)
+def credentials():
+    creds, _ = google.auth.default()
+    return creds
 
 
 @pytest.fixture()
@@ -77,8 +96,7 @@ def local_shakespeare_table_reference(project_id, use_mtls):
 def dataset(project_id, bq_client):
     from google.cloud import bigquery
 
-    unique_suffix = str(uuid.uuid4()).replace("-", "_")
-    dataset_name = "bq_storage_system_tests_" + unique_suffix
+    dataset_name = prefixer.create_prefix()
     dataset_id = "{}.{}".format(project_id, dataset_name)
 
     dataset = bigquery.Dataset(dataset_id)
@@ -120,35 +138,20 @@ def bq_client(credentials, use_mtls):
     return bigquery.Client(credentials=credentials)
 
 
+@pytest.fixture(scope="session", autouse=True)
+def cleanup_datasets(bq_client: bigquery.Client):
+    for dataset in bq_client.list_datasets():
+        if prefixer.should_cleanup(dataset.dataset_id):
+            bq_client.delete_dataset(dataset, delete_contents=True, not_found_ok=True)
+
+
 @pytest.fixture
 def all_types_table_ref(project_id, dataset, bq_client):
     from google.cloud import bigquery
 
-    schema = [
-        bigquery.SchemaField("string_field", "STRING"),
-        bigquery.SchemaField("bytes_field", "BYTES"),
-        bigquery.SchemaField("int64_field", "INT64"),
-        bigquery.SchemaField("float64_field", "FLOAT64"),
-        bigquery.SchemaField("numeric_field", "NUMERIC"),
-        bigquery.SchemaField("bool_field", "BOOL"),
-        bigquery.SchemaField("geography_field", "GEOGRAPHY"),
-        bigquery.SchemaField(
-            "person_struct_field",
-            "STRUCT",
-            fields=(
-                bigquery.SchemaField("name", "STRING"),
-                bigquery.SchemaField("age", "INT64"),
-            ),
-        ),
-        bigquery.SchemaField("timestamp_field", "TIMESTAMP"),
-        bigquery.SchemaField("date_field", "DATE"),
-        bigquery.SchemaField("time_field", "TIME"),
-        bigquery.SchemaField("datetime_field", "DATETIME"),
-        bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"),
-    ]
     bq_table = bigquery.table.Table(
         table_ref="{}.{}.complex_records".format(project_id, dataset.dataset_id),
-        schema=schema,
+        schema=_ALL_TYPES_SCHEMA,
     )
 
     created_table = bq_client.create_table(bq_table)
diff --git a/tests/unit/gapic/bigquery_storage_v1/test_big_query_read.py b/tests/unit/gapic/bigquery_storage_v1/test_big_query_read.py
index 9253a0a0..60d9a4bd 100644
--- a/tests/unit/gapic/bigquery_storage_v1/test_big_query_read.py
+++ b/tests/unit/gapic/bigquery_storage_v1/test_big_query_read.py
@@ -36,9 +36,6 @@
 )
 from google.cloud.bigquery_storage_v1.services.big_query_read import BigQueryReadClient
 from google.cloud.bigquery_storage_v1.services.big_query_read import transports
-from google.cloud.bigquery_storage_v1.services.big_query_read.transports.base import (
-    _API_CORE_VERSION,
-)
 from google.cloud.bigquery_storage_v1.services.big_query_read.transports.base import (
     _GOOGLE_AUTH_VERSION,
 )
@@ -51,8 +48,9 @@
 import google.auth
 
 
-# TODO(busunkim): Once google-api-core >= 1.26.0 is required:
-# - Delete all the api-core and auth "less than" test cases
+# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively
+# through google-api-core:
+# - Delete the auth "less than" test cases
 # - Delete these pytest markers (Make the "greater than or equal to" tests the default).
 requires_google_auth_lt_1_25_0 = pytest.mark.skipif(
     packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"),
@@ -63,16 +61,6 @@
     reason="This test requires google-auth >= 1.25.0",
 )
 
-requires_api_core_lt_1_26_0 = pytest.mark.skipif(
-    packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"),
-    reason="This test requires google-api-core < 1.26.0",
-)
-
-requires_api_core_gte_1_26_0 = pytest.mark.skipif(
-    packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"),
-    reason="This test requires google-api-core >= 1.26.0",
-)
-
 
 def client_cert_source_callback():
     return b"cert bytes", b"key bytes"
@@ -130,6 +118,34 @@ def test_big_query_read_client_from_service_account_info(client_class):
         assert client.transport._host == "bigquerystorage.googleapis.com:443"
 
 
+@pytest.mark.parametrize("client_class", [BigQueryReadClient, BigQueryReadAsyncClient,])
+def test_big_query_read_client_service_account_always_use_jwt(client_class):
+    with mock.patch.object(
+        service_account.Credentials, "with_always_use_jwt_access", create=True
+    ) as use_jwt:
+        creds = service_account.Credentials(None, None, None)
+        client = client_class(credentials=creds)
+        use_jwt.assert_not_called()
+
+
+@pytest.mark.parametrize(
+    "transport_class,transport_name",
+    [
+        (transports.BigQueryReadGrpcTransport, "grpc"),
+        (transports.BigQueryReadGrpcAsyncIOTransport, "grpc_asyncio"),
+    ],
+)
+def test_big_query_read_client_service_account_always_use_jwt_true(
+    transport_class, transport_name
+):
+    with mock.patch.object(
+        service_account.Credentials, "with_always_use_jwt_access", create=True
+    ) as use_jwt:
+        creds = service_account.Credentials(None, None, None)
+        transport = transport_class(credentials=creds, always_use_jwt_access=True)
+        use_jwt.assert_called_once_with(True)
+
+
 @pytest.mark.parametrize("client_class", [BigQueryReadClient, BigQueryReadAsyncClient,])
 def test_big_query_read_client_from_service_account_file(client_class):
     creds = ga_credentials.AnonymousCredentials()
@@ -1299,7 +1315,6 @@ def test_big_query_read_transport_auth_adc_old_google_auth(transport_class):
         (transports.BigQueryReadGrpcAsyncIOTransport, grpc_helpers_async),
     ],
 )
-@requires_api_core_gte_1_26_0
 def test_big_query_read_transport_create_channel(transport_class, grpc_helpers):
     # If credentials and host are not provided, the transport class should use
     # ADC credentials.
@@ -1332,83 +1347,6 @@ def test_big_query_read_transport_create_channel(transport_class, grpc_helpers):
     )
 
 
-@pytest.mark.parametrize(
-    "transport_class,grpc_helpers",
-    [
-        (transports.BigQueryReadGrpcTransport, grpc_helpers),
-        (transports.BigQueryReadGrpcAsyncIOTransport, grpc_helpers_async),
-    ],
-)
-@requires_api_core_lt_1_26_0
-def test_big_query_read_transport_create_channel_old_api_core(
-    transport_class, grpc_helpers
-):
-    # If credentials and host are not provided, the transport class should use
-    # ADC credentials.
-    with mock.patch.object(
-        google.auth, "default", autospec=True
-    ) as adc, mock.patch.object(
-        grpc_helpers, "create_channel", autospec=True
-    ) as create_channel:
-        creds = ga_credentials.AnonymousCredentials()
-        adc.return_value = (creds, None)
-        transport_class(quota_project_id="octopus")
-
-        create_channel.assert_called_with(
-            "bigquerystorage.googleapis.com:443",
-            credentials=creds,
-            credentials_file=None,
-            quota_project_id="octopus",
-            scopes=(
-                "https://www.googleapis.com/auth/bigquery",
-                "https://www.googleapis.com/auth/bigquery.readonly",
-                "https://www.googleapis.com/auth/cloud-platform",
-            ),
-            ssl_credentials=None,
-            options=[
-                ("grpc.max_send_message_length", -1),
-                ("grpc.max_receive_message_length", -1),
-            ],
-        )
-
-
-@pytest.mark.parametrize(
-    "transport_class,grpc_helpers",
-    [
-        (transports.BigQueryReadGrpcTransport, grpc_helpers),
-        (transports.BigQueryReadGrpcAsyncIOTransport, grpc_helpers_async),
-    ],
-)
-@requires_api_core_lt_1_26_0
-def test_big_query_read_transport_create_channel_user_scopes(
-    transport_class, grpc_helpers
-):
-    # If credentials and host are not provided, the transport class should use
-    # ADC credentials.
-    with mock.patch.object(
-        google.auth, "default", autospec=True
-    ) as adc, mock.patch.object(
-        grpc_helpers, "create_channel", autospec=True
-    ) as create_channel:
-        creds = ga_credentials.AnonymousCredentials()
-        adc.return_value = (creds, None)
-
-        transport_class(quota_project_id="octopus", scopes=["1", "2"])
-
-        create_channel.assert_called_with(
-            "bigquerystorage.googleapis.com:443",
-            credentials=creds,
-            credentials_file=None,
-            quota_project_id="octopus",
-            scopes=["1", "2"],
-            ssl_credentials=None,
-            options=[
-                ("grpc.max_send_message_length", -1),
-                ("grpc.max_receive_message_length", -1),
-            ],
-        )
-
-
 @pytest.mark.parametrize(
     "transport_class",
     [transports.BigQueryReadGrpcTransport, transports.BigQueryReadGrpcAsyncIOTransport],
@@ -1428,11 +1366,7 @@ def test_big_query_read_grpc_transport_client_cert_source_for_mtls(transport_cla
         "squid.clam.whelk:443",
         credentials=cred,
         credentials_file=None,
-        scopes=(
-            "https://www.googleapis.com/auth/bigquery",
-            "https://www.googleapis.com/auth/bigquery.readonly",
-            "https://www.googleapis.com/auth/cloud-platform",
-        ),
+        scopes=None,
         ssl_credentials=mock_ssl_channel_creds,
         quota_project_id=None,
         options=[
@@ -1536,11 +1470,7 @@ def test_big_query_read_transport_channel_mtls_with_client_cert_source(transport
         "mtls.squid.clam.whelk:443",
         credentials=cred,
         credentials_file=None,
-        scopes=(
-            "https://www.googleapis.com/auth/bigquery",
-            "https://www.googleapis.com/auth/bigquery.readonly",
-            "https://www.googleapis.com/auth/cloud-platform",
-        ),
+        scopes=None,
         ssl_credentials=mock_ssl_cred,
         quota_project_id=None,
         options=[
@@ -1584,11 +1514,7 @@ def test_big_query_read_transport_channel_mtls_with_adc(transport_class):
         "mtls.squid.clam.whelk:443",
         credentials=mock_cred,
         credentials_file=None,
-        scopes=(
-            "https://www.googleapis.com/auth/bigquery",
-            "https://www.googleapis.com/auth/bigquery.readonly",
-            "https://www.googleapis.com/auth/cloud-platform",
-        ),
+        scopes=None,
         ssl_credentials=mock_ssl_cred,
         quota_project_id=None,
         options=[
diff --git a/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_read.py b/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_read.py
index e7481e67..407b2e5c 100644
--- a/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_read.py
+++ b/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_read.py
@@ -38,9 +38,6 @@
     BigQueryReadClient,
 )
 from google.cloud.bigquery_storage_v1beta2.services.big_query_read import transports
-from google.cloud.bigquery_storage_v1beta2.services.big_query_read.transports.base import (
-    _API_CORE_VERSION,
-)
 from google.cloud.bigquery_storage_v1beta2.services.big_query_read.transports.base import (
     _GOOGLE_AUTH_VERSION,
 )
@@ -53,8 +50,9 @@
 import google.auth
 
 
-# TODO(busunkim): Once google-api-core >= 1.26.0 is required:
-# - Delete all the api-core and auth "less than" test cases
+# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively
+# through google-api-core:
+# - Delete the auth "less than" test cases
 # - Delete these pytest markers (Make the "greater than or equal to" tests the default).
 requires_google_auth_lt_1_25_0 = pytest.mark.skipif(
     packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"),
@@ -65,16 +63,6 @@
     reason="This test requires google-auth >= 1.25.0",
 )
 
-requires_api_core_lt_1_26_0 = pytest.mark.skipif(
-    packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"),
-    reason="This test requires google-api-core < 1.26.0",
-)
-
-requires_api_core_gte_1_26_0 = pytest.mark.skipif(
-    packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"),
-    reason="This test requires google-api-core >= 1.26.0",
-)
-
 
 def client_cert_source_callback():
     return b"cert bytes", b"key bytes"
@@ -132,6 +120,34 @@ def test_big_query_read_client_from_service_account_info(client_class):
         assert client.transport._host == "bigquerystorage.googleapis.com:443"
 
 
+@pytest.mark.parametrize("client_class", [BigQueryReadClient, BigQueryReadAsyncClient,])
+def test_big_query_read_client_service_account_always_use_jwt(client_class):
+    with mock.patch.object(
+        service_account.Credentials, "with_always_use_jwt_access", create=True
+    ) as use_jwt:
+        creds = service_account.Credentials(None, None, None)
+        client = client_class(credentials=creds)
+        use_jwt.assert_not_called()
+
+
+@pytest.mark.parametrize(
+    "transport_class,transport_name",
+    [
+        (transports.BigQueryReadGrpcTransport, "grpc"),
+        (transports.BigQueryReadGrpcAsyncIOTransport, "grpc_asyncio"),
+    ],
+)
+def test_big_query_read_client_service_account_always_use_jwt_true(
+    transport_class, transport_name
+):
+    with mock.patch.object(
+        service_account.Credentials, "with_always_use_jwt_access", create=True
+    ) as use_jwt:
+        creds = service_account.Credentials(None, None, None)
+        transport = transport_class(credentials=creds, always_use_jwt_access=True)
+        use_jwt.assert_called_once_with(True)
+
+
 @pytest.mark.parametrize("client_class", [BigQueryReadClient, BigQueryReadAsyncClient,])
 def test_big_query_read_client_from_service_account_file(client_class):
     creds = ga_credentials.AnonymousCredentials()
@@ -1301,7 +1317,6 @@ def test_big_query_read_transport_auth_adc_old_google_auth(transport_class):
         (transports.BigQueryReadGrpcAsyncIOTransport, grpc_helpers_async),
     ],
 )
-@requires_api_core_gte_1_26_0
 def test_big_query_read_transport_create_channel(transport_class, grpc_helpers):
     # If credentials and host are not provided, the transport class should use
     # ADC credentials.
@@ -1334,83 +1349,6 @@ def test_big_query_read_transport_create_channel(transport_class, grpc_helpers):
     )
 
 
-@pytest.mark.parametrize(
-    "transport_class,grpc_helpers",
-    [
-        (transports.BigQueryReadGrpcTransport, grpc_helpers),
-        (transports.BigQueryReadGrpcAsyncIOTransport, grpc_helpers_async),
-    ],
-)
-@requires_api_core_lt_1_26_0
-def test_big_query_read_transport_create_channel_old_api_core(
-    transport_class, grpc_helpers
-):
-    # If credentials and host are not provided, the transport class should use
-    # ADC credentials.
-    with mock.patch.object(
-        google.auth, "default", autospec=True
-    ) as adc, mock.patch.object(
-        grpc_helpers, "create_channel", autospec=True
-    ) as create_channel:
-        creds = ga_credentials.AnonymousCredentials()
-        adc.return_value = (creds, None)
-        transport_class(quota_project_id="octopus")
-
-        create_channel.assert_called_with(
-            "bigquerystorage.googleapis.com:443",
-            credentials=creds,
-            credentials_file=None,
-            quota_project_id="octopus",
-            scopes=(
-                "https://www.googleapis.com/auth/bigquery",
-                "https://www.googleapis.com/auth/bigquery.readonly",
-                "https://www.googleapis.com/auth/cloud-platform",
-            ),
-            ssl_credentials=None,
-            options=[
-                ("grpc.max_send_message_length", -1),
-                ("grpc.max_receive_message_length", -1),
-            ],
-        )
-
-
-@pytest.mark.parametrize(
-    "transport_class,grpc_helpers",
-    [
-        (transports.BigQueryReadGrpcTransport, grpc_helpers),
-        (transports.BigQueryReadGrpcAsyncIOTransport, grpc_helpers_async),
-    ],
-)
-@requires_api_core_lt_1_26_0
-def test_big_query_read_transport_create_channel_user_scopes(
-    transport_class, grpc_helpers
-):
-    # If credentials and host are not provided, the transport class should use
-    # ADC credentials.
-    with mock.patch.object(
-        google.auth, "default", autospec=True
-    ) as adc, mock.patch.object(
-        grpc_helpers, "create_channel", autospec=True
-    ) as create_channel:
-        creds = ga_credentials.AnonymousCredentials()
-        adc.return_value = (creds, None)
-
-        transport_class(quota_project_id="octopus", scopes=["1", "2"])
-
-        create_channel.assert_called_with(
-            "bigquerystorage.googleapis.com:443",
-            credentials=creds,
-            credentials_file=None,
-            quota_project_id="octopus",
-            scopes=["1", "2"],
-            ssl_credentials=None,
-            options=[
-                ("grpc.max_send_message_length", -1),
-                ("grpc.max_receive_message_length", -1),
-            ],
-        )
-
-
 @pytest.mark.parametrize(
     "transport_class",
     [transports.BigQueryReadGrpcTransport, transports.BigQueryReadGrpcAsyncIOTransport],
@@ -1430,11 +1368,7 @@ def test_big_query_read_grpc_transport_client_cert_source_for_mtls(transport_cla
         "squid.clam.whelk:443",
         credentials=cred,
         credentials_file=None,
-        scopes=(
-            "https://www.googleapis.com/auth/bigquery",
-            "https://www.googleapis.com/auth/bigquery.readonly",
-            "https://www.googleapis.com/auth/cloud-platform",
-        ),
+        scopes=None,
         ssl_credentials=mock_ssl_channel_creds,
         quota_project_id=None,
         options=[
@@ -1538,11 +1472,7 @@ def test_big_query_read_transport_channel_mtls_with_client_cert_source(transport
         "mtls.squid.clam.whelk:443",
         credentials=cred,
         credentials_file=None,
-        scopes=(
-            "https://www.googleapis.com/auth/bigquery",
-            "https://www.googleapis.com/auth/bigquery.readonly",
-            "https://www.googleapis.com/auth/cloud-platform",
-        ),
+        scopes=None,
         ssl_credentials=mock_ssl_cred,
         quota_project_id=None,
         options=[
@@ -1586,11 +1516,7 @@ def test_big_query_read_transport_channel_mtls_with_adc(transport_class):
         "mtls.squid.clam.whelk:443",
         credentials=mock_cred,
         credentials_file=None,
-        scopes=(
-            "https://www.googleapis.com/auth/bigquery",
-            "https://www.googleapis.com/auth/bigquery.readonly",
-            "https://www.googleapis.com/auth/cloud-platform",
-        ),
+        scopes=None,
         ssl_credentials=mock_ssl_cred,
         quota_project_id=None,
         options=[
diff --git a/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_write.py b/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_write.py
index c0172c65..5557219f 100644
--- a/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_write.py
+++ b/tests/unit/gapic/bigquery_storage_v1beta2/test_big_query_write.py
@@ -38,9 +38,6 @@
     BigQueryWriteClient,
 )
 from google.cloud.bigquery_storage_v1beta2.services.big_query_write import transports
-from google.cloud.bigquery_storage_v1beta2.services.big_query_write.transports.base import (
-    _API_CORE_VERSION,
-)
 from google.cloud.bigquery_storage_v1beta2.services.big_query_write.transports.base import (
     _GOOGLE_AUTH_VERSION,
 )
@@ -56,8 +53,9 @@
 import google.auth
 
 
-# TODO(busunkim): Once google-api-core >= 1.26.0 is required:
-# - Delete all the api-core and auth "less than" test cases
+# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively
+# through google-api-core:
+# - Delete the auth "less than" test cases
 # - Delete these pytest markers (Make the "greater than or equal to" tests the default).
 requires_google_auth_lt_1_25_0 = pytest.mark.skipif(
     packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"),
@@ -68,16 +66,6 @@
     reason="This test requires google-auth >= 1.25.0",
 )
 
-requires_api_core_lt_1_26_0 = pytest.mark.skipif(
-    packaging.version.parse(_API_CORE_VERSION) >= packaging.version.parse("1.26.0"),
-    reason="This test requires google-api-core < 1.26.0",
-)
-
-requires_api_core_gte_1_26_0 = pytest.mark.skipif(
-    packaging.version.parse(_API_CORE_VERSION) < packaging.version.parse("1.26.0"),
-    reason="This test requires google-api-core >= 1.26.0",
-)
-
 
 def client_cert_source_callback():
     return b"cert bytes", b"key bytes"
@@ -140,6 +128,36 @@ def test_big_query_write_client_from_service_account_info(client_class):
         assert client.transport._host == "bigquerystorage.googleapis.com:443"
 
 
+@pytest.mark.parametrize(
+    "client_class", [BigQueryWriteClient, BigQueryWriteAsyncClient,]
+)
+def test_big_query_write_client_service_account_always_use_jwt(client_class):
+    with mock.patch.object(
+        service_account.Credentials, "with_always_use_jwt_access", create=True
+    ) as use_jwt:
+        creds = service_account.Credentials(None, None, None)
+        client = client_class(credentials=creds)
+        use_jwt.assert_not_called()
+
+
+@pytest.mark.parametrize(
+    "transport_class,transport_name",
+    [
+        (transports.BigQueryWriteGrpcTransport, "grpc"),
+        (transports.BigQueryWriteGrpcAsyncIOTransport, "grpc_asyncio"),
+    ],
+)
+def test_big_query_write_client_service_account_always_use_jwt_true(
+    transport_class, transport_name
+):
+    with mock.patch.object(
+        service_account.Credentials, "with_always_use_jwt_access", create=True
+    ) as use_jwt:
+        creds = service_account.Credentials(None, None, None)
+        transport = transport_class(credentials=creds, always_use_jwt_access=True)
+        use_jwt.assert_called_once_with(True)
+
+
 @pytest.mark.parametrize(
     "client_class", [BigQueryWriteClient, BigQueryWriteAsyncClient,]
 )
@@ -1872,7 +1890,6 @@ def test_big_query_write_transport_auth_adc_old_google_auth(transport_class):
         (transports.BigQueryWriteGrpcAsyncIOTransport, grpc_helpers_async),
     ],
 )
-@requires_api_core_gte_1_26_0
 def test_big_query_write_transport_create_channel(transport_class, grpc_helpers):
     # If credentials and host are not provided, the transport class should use
     # ADC credentials.
@@ -1905,83 +1922,6 @@ def test_big_query_write_transport_create_channel(transport_class, grpc_helpers)
     )
 
 
-@pytest.mark.parametrize(
-    "transport_class,grpc_helpers",
-    [
-        (transports.BigQueryWriteGrpcTransport, grpc_helpers),
-        (transports.BigQueryWriteGrpcAsyncIOTransport, grpc_helpers_async),
-    ],
-)
-@requires_api_core_lt_1_26_0
-def test_big_query_write_transport_create_channel_old_api_core(
-    transport_class, grpc_helpers
-):
-    # If credentials and host are not provided, the transport class should use
-    # ADC credentials.
-    with mock.patch.object(
-        google.auth, "default", autospec=True
-    ) as adc, mock.patch.object(
-        grpc_helpers, "create_channel", autospec=True
-    ) as create_channel:
-        creds = ga_credentials.AnonymousCredentials()
-        adc.return_value = (creds, None)
-        transport_class(quota_project_id="octopus")
-
-        create_channel.assert_called_with(
-            "bigquerystorage.googleapis.com:443",
-            credentials=creds,
-            credentials_file=None,
-            quota_project_id="octopus",
-            scopes=(
-                "https://www.googleapis.com/auth/bigquery",
-                "https://www.googleapis.com/auth/bigquery.insertdata",
-                "https://www.googleapis.com/auth/cloud-platform",
-            ),
-            ssl_credentials=None,
-            options=[
-                ("grpc.max_send_message_length", -1),
-                ("grpc.max_receive_message_length", -1),
-            ],
-        )
-
-
-@pytest.mark.parametrize(
-    "transport_class,grpc_helpers",
-    [
-        (transports.BigQueryWriteGrpcTransport, grpc_helpers),
-        (transports.BigQueryWriteGrpcAsyncIOTransport, grpc_helpers_async),
-    ],
-)
-@requires_api_core_lt_1_26_0
-def test_big_query_write_transport_create_channel_user_scopes(
-    transport_class, grpc_helpers
-):
-    # If credentials and host are not provided, the transport class should use
-    # ADC credentials.
-    with mock.patch.object(
-        google.auth, "default", autospec=True
-    ) as adc, mock.patch.object(
-        grpc_helpers, "create_channel", autospec=True
-    ) as create_channel:
-        creds = ga_credentials.AnonymousCredentials()
-        adc.return_value = (creds, None)
-
-        transport_class(quota_project_id="octopus", scopes=["1", "2"])
-
-        create_channel.assert_called_with(
-            "bigquerystorage.googleapis.com:443",
-            credentials=creds,
-            credentials_file=None,
-            quota_project_id="octopus",
-            scopes=["1", "2"],
-            ssl_credentials=None,
-            options=[
-                ("grpc.max_send_message_length", -1),
-                ("grpc.max_receive_message_length", -1),
-            ],
-        )
-
-
 @pytest.mark.parametrize(
     "transport_class",
     [
@@ -2004,11 +1944,7 @@ def test_big_query_write_grpc_transport_client_cert_source_for_mtls(transport_cl
         "squid.clam.whelk:443",
         credentials=cred,
         credentials_file=None,
-        scopes=(
-            "https://www.googleapis.com/auth/bigquery",
-            "https://www.googleapis.com/auth/bigquery.insertdata",
-            "https://www.googleapis.com/auth/cloud-platform",
-        ),
+        scopes=None,
         ssl_credentials=mock_ssl_channel_creds,
         quota_project_id=None,
         options=[
@@ -2117,11 +2053,7 @@ def test_big_query_write_transport_channel_mtls_with_client_cert_source(
         "mtls.squid.clam.whelk:443",
         credentials=cred,
         credentials_file=None,
-        scopes=(
-            "https://www.googleapis.com/auth/bigquery",
-            "https://www.googleapis.com/auth/bigquery.insertdata",
-            "https://www.googleapis.com/auth/cloud-platform",
-        ),
+        scopes=None,
         ssl_credentials=mock_ssl_cred,
         quota_project_id=None,
         options=[
@@ -2168,11 +2100,7 @@ def test_big_query_write_transport_channel_mtls_with_adc(transport_class):
         "mtls.squid.clam.whelk:443",
         credentials=mock_cred,
         credentials_file=None,
-        scopes=(
-            "https://www.googleapis.com/auth/bigquery",
-            "https://www.googleapis.com/auth/bigquery.insertdata",
-            "https://www.googleapis.com/auth/cloud-platform",
-        ),
+        scopes=None,
         ssl_credentials=mock_ssl_cred,
         quota_project_id=None,
         options=[
diff --git a/tests/unit/test_reader_v1.py b/tests/unit/test_reader_v1.py
index 7fb8d5a4..838ef51a 100644
--- a/tests/unit/test_reader_v1.py
+++ b/tests/unit/test_reader_v1.py
@@ -66,6 +66,7 @@ def mock_gapic_client():
 def _bq_to_avro_blocks(bq_blocks, avro_schema_json):
     avro_schema = fastavro.parse_schema(avro_schema_json)
     avro_blocks = []
+    first_message = True
     for block in bq_blocks:
         blockio = six.BytesIO()
         for row in block:
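The hunk below teaches this test helper the contract the reader now assumes: only the first ReadRowsResponse of a stream carries the schema. A hedged sketch of that message shape, using only calls that appear in this patch (the record name, column, and values are illustrative):

# Sketch: the schema rides only on the first response in a read stream;
# subsequent responses carry serialized rows only.
import json

from google.cloud.bigquery_storage_v1 import types

schema_json = {
    "type": "record",
    "name": "__root__",
    "fields": [{"name": "int_col", "type": "long"}],
}
first = types.ReadRowsResponse()
first.avro_schema = {"schema": json.dumps(schema_json)}  # schema here only
later = types.ReadRowsResponse()  # no schema set; rows only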
@@ -73,6 +74,9 @@ def _bq_to_avro_blocks(bq_blocks, avro_schema_json): response = types.ReadRowsResponse() response.row_count = len(block) response.avro_rows.serialized_binary_rows = blockio.getvalue() + if first_message: + response.avro_schema = {"schema": json.dumps(avro_schema_json)} + first_message = False avro_blocks.append(response) return avro_blocks @@ -128,54 +132,48 @@ def _bq_to_avro_schema(bq_columns): return avro_schema -def _get_avro_bytes(rows, avro_schema): - avro_file = six.BytesIO() - for row in rows: - fastavro.schemaless_writer(avro_file, avro_schema, row) - return avro_file.getvalue() - - def test_avro_rows_raises_import_error( mut, class_under_test, mock_gapic_client, monkeypatch ): monkeypatch.setattr(mut, "fastavro", None) - reader = class_under_test([], mock_gapic_client, "", 0, {}) - - bq_columns = [{"name": "int_col", "type": "int64"}] - avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) + avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) + avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) + reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + rows = iter(reader.rows()) + # Since session isn't passed in, reader doesn't know serialization type + # until you start iterating. with pytest.raises(ImportError): - reader.rows(read_session) + next(rows) def test_rows_no_schema_set_raises_type_error( mut, class_under_test, mock_gapic_client, monkeypatch ): - reader = class_under_test([], mock_gapic_client, "", 0, {}) - read_session = types.ReadSession() + avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) + avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) + avro_blocks[0].avro_schema = None + reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + rows = iter(reader.rows()) + # Since session isn't passed in, reader doesn't know serialization type + # until you start iterating. with pytest.raises(TypeError): - reader.rows(read_session) + next(rows) def test_rows_w_empty_stream(class_under_test, mock_gapic_client): - bq_columns = [{"name": "int_col", "type": "int64"}] - avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) reader = class_under_test([], mock_gapic_client, "", 0, {}) - - got = reader.rows(read_session) + got = reader.rows() assert tuple(got) == () def test_rows_w_scalars(class_under_test, mock_gapic_client): avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) - read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) - got = tuple(reader.rows(read_session)) + got = tuple(reader.rows()) expected = tuple(itertools.chain.from_iterable(SCALAR_BLOCKS)) assert got == expected @@ -184,7 +182,6 @@ def test_rows_w_scalars(class_under_test, mock_gapic_client): def test_rows_w_timeout(class_under_test, mock_gapic_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) bq_blocks_1 = [ [{"int_col": 123}, {"int_col": 234}], [{"int_col": 345}, {"int_col": 456}], @@ -206,7 +203,7 @@ def test_rows_w_timeout(class_under_test, mock_gapic_client): ) with pytest.raises(google.api_core.exceptions.DeadlineExceeded): - list(reader.rows(read_session)) + list(reader.rows()) # Don't reconnect on DeadlineException. This allows user-specified timeouts # to be respected. 
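The tests above exercise the optional-session path end to end: the Avro schema now arrives in the first ReadRowsResponse, so callers no longer have to thread the read session through to rows(). A minimal sketch of the resulting call pattern, assuming Application Default Credentials are configured; the project, dataset, and table names are placeholders:

    from google.cloud import bigquery_storage_v1

    client = bigquery_storage_v1.BigQueryReadClient()
    session = client.create_read_session(
        parent="projects/my-project",
        read_session={
            "table": "projects/my-project/datasets/my_dataset/tables/my_table",
            "data_format": bigquery_storage_v1.types.DataFormat.AVRO,
        },
        max_stream_count=1,
    )
    reader = client.read_rows(session.streams[0].name)

    # No read_session argument: the first message in the stream supplies
    # the schema needed to decode rows.
    for row in reader.rows():
        print(row)

Passing read_session here still works, it is merely deprecated; empty streams are the one case where it remains useful, as the to_dataframe tests below show.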
@@ -216,7 +213,6 @@ def test_rows_w_timeout(class_under_test, mock_gapic_client): def test_rows_w_nonresumable_internal_error(class_under_test, mock_gapic_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) bq_blocks = [[{"int_col": 1024}, {"int_col": 512}], [{"int_col": 256}]] avro_blocks = _pages_w_nonresumable_internal_error( _bq_to_avro_blocks(bq_blocks, avro_schema) @@ -227,7 +223,7 @@ def test_rows_w_nonresumable_internal_error(class_under_test, mock_gapic_client) with pytest.raises( google.api_core.exceptions.InternalServerError, match="nonresumable error" ): - list(reader.rows(read_session)) + list(reader.rows()) mock_gapic_client.read_rows.assert_not_called() @@ -235,7 +231,6 @@ def test_rows_w_nonresumable_internal_error(class_under_test, mock_gapic_client) def test_rows_w_reconnect(class_under_test, mock_gapic_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) bq_blocks_1 = [ [{"int_col": 123}, {"int_col": 234}], [{"int_col": 345}, {"int_col": 456}], @@ -258,7 +253,7 @@ def test_rows_w_reconnect(class_under_test, mock_gapic_client): 0, {"metadata": {"test-key": "test-value"}}, ) - got = reader.rows(read_session) + got = reader.rows() expected = tuple( itertools.chain( @@ -280,7 +275,6 @@ def test_rows_w_reconnect(class_under_test, mock_gapic_client): def test_rows_w_reconnect_by_page(class_under_test, mock_gapic_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) bq_blocks_1 = [ [{"int_col": 123}, {"int_col": 234}], [{"int_col": 345}, {"int_col": 456}], @@ -298,7 +292,7 @@ def test_rows_w_reconnect_by_page(class_under_test, mock_gapic_client): 0, {"metadata": {"test-key": "test-value"}}, ) - got = reader.rows(read_session) + got = reader.rows() pages = iter(got.pages) page_1 = next(pages) @@ -330,38 +324,41 @@ def test_to_dataframe_no_pandas_raises_import_error( ): monkeypatch.setattr(mut, "pandas", None) avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) - read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) with pytest.raises(ImportError): - reader.to_dataframe(read_session) + reader.to_dataframe() with pytest.raises(ImportError): - reader.rows(read_session).to_dataframe() + reader.rows().to_dataframe() with pytest.raises(ImportError): - next(reader.rows(read_session).pages).to_dataframe() + next(reader.rows().pages).to_dataframe() def test_to_dataframe_no_schema_set_raises_type_error( mut, class_under_test, mock_gapic_client, monkeypatch ): - reader = class_under_test([], mock_gapic_client, "", 0, {}) - read_session = types.ReadSession() + avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) + avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) + avro_blocks[0].avro_schema = None + reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + rows = reader.rows() + # Since session isn't passed in, reader doesn't know serialization type + # until you start iterating. 
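# With the schema cleared from the first message and no read_session
# fallback, the reader has nothing to build a stream parser from, so
# iteration fails (the TypeError presumably originates in the parser
# factory rejecting a message with no recognized schema):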
with pytest.raises(TypeError): - reader.to_dataframe(read_session) + rows.to_dataframe() def test_to_dataframe_w_scalars(class_under_test): avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) - read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) - got = reader.to_dataframe(read_session) + got = reader.to_dataframe() expected = pandas.DataFrame( list(itertools.chain.from_iterable(SCALAR_BLOCKS)), columns=SCALAR_COLUMN_NAMES @@ -392,7 +389,6 @@ def test_to_dataframe_w_dtypes(class_under_test): {"name": "lilfloat", "type": "float64"}, ] ) - read_session = _generate_avro_read_session(avro_schema) blocks = [ [{"bigfloat": 1.25, "lilfloat": 30.5}, {"bigfloat": 2.5, "lilfloat": 21.125}], [{"bigfloat": 3.75, "lilfloat": 11.0}], @@ -400,7 +396,7 @@ def test_to_dataframe_w_dtypes(class_under_test): avro_blocks = _bq_to_avro_blocks(blocks, avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) - got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) + got = reader.to_dataframe(dtypes={"lilfloat": "float16"}) expected = pandas.DataFrame( { @@ -421,6 +417,7 @@ def test_to_dataframe_empty_w_scalars_avro(class_under_test): avro_blocks = _bq_to_avro_blocks([], avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + # Read session is needed to get a schema for empty streams. got = reader.to_dataframe(read_session) expected = pandas.DataFrame(columns=SCALAR_COLUMN_NAMES) @@ -448,6 +445,7 @@ def test_to_dataframe_empty_w_dtypes_avro(class_under_test, mock_gapic_client): avro_blocks = _bq_to_avro_blocks([], avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + # Read session is needed to get a schema for empty streams. 
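# An empty stream never yields a first message, so the schema (and
# hence the DataFrame's column names) can only come from the session
# passed explicitly below.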
got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) expected = pandas.DataFrame([], columns=["bigfloat", "lilfloat"]) @@ -466,7 +464,6 @@ def test_to_dataframe_by_page(class_under_test, mock_gapic_client): {"name": "bool_col", "type": "bool"}, ] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) block_1 = [{"int_col": 123, "bool_col": True}, {"int_col": 234, "bool_col": False}] block_2 = [{"int_col": 345, "bool_col": True}, {"int_col": 456, "bool_col": False}] block_3 = [{"int_col": 567, "bool_col": True}, {"int_col": 789, "bool_col": False}] @@ -487,7 +484,7 @@ def test_to_dataframe_by_page(class_under_test, mock_gapic_client): 0, {"metadata": {"test-key": "test-value"}}, ) - got = reader.rows(read_session) + got = reader.rows() pages = iter(got.pages) page_1 = next(pages) diff --git a/tests/unit/test_reader_v1_arrow.py b/tests/unit/test_reader_v1_arrow.py index 492098f5..02c7b80a 100644 --- a/tests/unit/test_reader_v1_arrow.py +++ b/tests/unit/test_reader_v1_arrow.py @@ -84,11 +84,17 @@ def _bq_to_arrow_batch_objects(bq_blocks, arrow_schema): def _bq_to_arrow_batches(bq_blocks, arrow_schema): arrow_batches = [] + first_message = True for record_batch in _bq_to_arrow_batch_objects(bq_blocks, arrow_schema): response = types.ReadRowsResponse() response.arrow_record_batch.serialized_record_batch = ( record_batch.serialize().to_pybytes() ) + if first_message: + response.arrow_schema = { + "serialized_schema": arrow_schema.serialize().to_pybytes(), + } + first_message = False arrow_batches.append(response) return arrow_batches @@ -123,14 +129,15 @@ def test_pyarrow_rows_raises_import_error( mut, class_under_test, mock_gapic_client, monkeypatch ): monkeypatch.setattr(mut, "pyarrow", None) - reader = class_under_test([], mock_gapic_client, "", 0, {}) - - bq_columns = [{"name": "int_col", "type": "int64"}] - arrow_schema = _bq_to_arrow_schema(bq_columns) - read_session = _generate_arrow_read_session(arrow_schema) + arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) + arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) + reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + rows = iter(reader.rows()) + # Since session isn't passed in, reader doesn't know serialization type + # until you start iterating. 
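# pyarrow has been monkeypatched away above, so the failure can only
# surface once the first Arrow message actually needs to be decoded,
# hence next(rows) below rather than the rows() call itself: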
with pytest.raises(ImportError): - reader.rows(read_session) + next(rows) def test_to_arrow_no_pyarrow_raises_import_error( @@ -138,26 +145,24 @@ def test_to_arrow_no_pyarrow_raises_import_error( ): monkeypatch.setattr(mut, "pyarrow", None) arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) - read_session = _generate_arrow_read_session(arrow_schema) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) with pytest.raises(ImportError): - reader.to_arrow(read_session) + reader.to_arrow() with pytest.raises(ImportError): - reader.rows(read_session).to_arrow() + reader.rows().to_arrow() with pytest.raises(ImportError): - next(reader.rows(read_session).pages).to_arrow() + next(reader.rows().pages).to_arrow() def test_to_arrow_w_scalars_arrow(class_under_test): arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) - read_session = _generate_arrow_read_session(arrow_schema) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) - actual_table = reader.to_arrow(read_session) + actual_table = reader.to_arrow() expected_table = pyarrow.Table.from_batches( _bq_to_arrow_batch_objects(SCALAR_BLOCKS, arrow_schema) ) @@ -166,11 +171,10 @@ def test_to_arrow_w_scalars_arrow(class_under_test): def test_to_dataframe_w_scalars_arrow(class_under_test): arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) - read_session = _generate_arrow_read_session(arrow_schema) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) - got = reader.to_dataframe(read_session) + got = reader.to_dataframe() expected = pandas.DataFrame( list(itertools.chain.from_iterable(SCALAR_BLOCKS)), columns=SCALAR_COLUMN_NAMES @@ -183,24 +187,19 @@ def test_to_dataframe_w_scalars_arrow(class_under_test): def test_rows_w_empty_stream_arrow(class_under_test, mock_gapic_client): - bq_columns = [{"name": "int_col", "type": "int64"}] - arrow_schema = _bq_to_arrow_schema(bq_columns) - read_session = _generate_arrow_read_session(arrow_schema) reader = class_under_test([], mock_gapic_client, "", 0, {}) - - got = reader.rows(read_session) + got = reader.rows() assert tuple(got) == () def test_rows_w_scalars_arrow(class_under_test, mock_gapic_client): arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) - read_session = _generate_arrow_read_session(arrow_schema) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) got = tuple( dict((key, value.as_py()) for key, value in row_dict.items()) - for row_dict in reader.rows(read_session) + for row_dict in reader.rows() ) expected = tuple(itertools.chain.from_iterable(SCALAR_BLOCKS)) @@ -214,7 +213,6 @@ def test_to_dataframe_w_dtypes_arrow(class_under_test): {"name": "lilfloat", "type": "float64"}, ] ) - read_session = _generate_arrow_read_session(arrow_schema) blocks = [ [{"bigfloat": 1.25, "lilfloat": 30.5}, {"bigfloat": 2.5, "lilfloat": 21.125}], [{"bigfloat": 3.75, "lilfloat": 11.0}], @@ -222,7 +220,7 @@ def test_to_dataframe_w_dtypes_arrow(class_under_test): arrow_batches = _bq_to_arrow_batches(blocks, arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) - got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) + got = reader.to_dataframe(dtypes={"lilfloat": "float16"}) expected = pandas.DataFrame( { @@ -243,6 +241,7 @@ def 
test_to_dataframe_empty_w_scalars_arrow(class_under_test): arrow_batches = _bq_to_arrow_batches([], arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + # Read session is needed to get a schema for empty streams. got = reader.to_dataframe(read_session) expected = pandas.DataFrame([], columns=SCALAR_COLUMN_NAMES) @@ -270,6 +269,7 @@ def test_to_dataframe_empty_w_dtypes_arrow(class_under_test, mock_gapic_client): arrow_batches = _bq_to_arrow_batches([], arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + # Read session is needed to get a schema for empty streams. got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) expected = pandas.DataFrame([], columns=["bigfloat", "lilfloat"]) @@ -288,7 +288,6 @@ def test_to_dataframe_by_page_arrow(class_under_test, mock_gapic_client): {"name": "bool_col", "type": "bool"}, ] arrow_schema = _bq_to_arrow_schema(bq_columns) - read_session = _generate_arrow_read_session(arrow_schema) bq_block_1 = [ {"int_col": 123, "bool_col": True}, @@ -315,7 +314,7 @@ def test_to_dataframe_by_page_arrow(class_under_test, mock_gapic_client): reader = class_under_test( _pages_w_unavailable(batch_1), mock_gapic_client, "", 0, {} ) - got = reader.rows(read_session) + got = reader.rows() pages = iter(got.pages) page_1 = next(pages)
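The by-page tests rely on the same lazily initialized parser: ReadRowsIterable only constructs its stream parser once the first message arrives, so page iteration behaves identically whether or not a session was supplied. A minimal sketch of page-wise consumption, reusing the placeholder client and session from the earlier snippet:

    reader = client.read_rows(session.streams[0].name)

    for page in reader.rows().pages:
        # Each page wraps a single ReadRowsResponse; num_items reports its
        # row count, and each page can be materialized on its own.
        print(page.num_items)
        frame = page.to_dataframe()  # requires pandas

Transient errors such as Unavailable trigger a reconnect at the current offset (as the reconnect tests above verify), so iteration continues without user intervention.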