Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions packages/sqlalchemy-bigquery/mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# mypy static type-checking configuration for the sqlalchemy-bigquery package.
[mypy]
# Target interpreter version for type analysis.
# NOTE(review): confirm the mypy release pinned in noxfile.py (<1.16.0) accepts 3.14.
python_version = 3.14
namespace_packages = True
# Fail on imports without type information unless explicitly exempted below.
ignore_missing_imports = False

# Helps mypy navigate the 'google' namespace more reliably in 3.10+
explicit_package_bases = True

# Performance: reuse results from previous runs to speed up 'nox'
incremental = True

# Optional dependency for GIS/spatial support. Lacks type hints.
[mypy-geoalchemy2.*]
ignore_missing_imports = True

# Optional dependency used by geoalchemy2. Lacks type hints.
[mypy-shapely.*]
ignore_missing_imports = True

# Optional dependency for database migrations. Lacks type hints.
[mypy-alembic.*]
ignore_missing_imports = True

# Obsolete package checked in __init__.py only to warn users that it is not supported
# and that they should use sqlalchemy-bigquery instead. Not installed.
[mypy-pybigquery.*]
ignore_missing_imports = True

# TODO(https://github.com/googleapis/gapic-generator-python/issues/2563):
# Dependencies that historically lack py.typed markers
[mypy-google.iam.*]
ignore_missing_imports = True
18 changes: 14 additions & 4 deletions packages/sqlalchemy-bigquery/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,12 +683,22 @@ def prerelease_deps(session, protobuf_implementation):
)


@nox.session(python=ALL_PYTHON)
def mypy(session):
    """Run the mypy type checker over the installed package.

    Installs a pinned mypy plus stub packages for untyped third-party
    dependencies, installs this package so mypy can resolve its modules,
    then type-checks the ``sqlalchemy_bigquery`` package. Extra arguments
    may be forwarded via nox posargs.
    """
    session.install(
        # Pinned below 1.16 so results stay stable across environments.
        "mypy<1.16.0",
        # Type stubs for dependencies that ship without inline types.
        "types-requests",
        "types-protobuf",
    )
    # Install the package under test so mypy sees its real modules.
    session.install(".")
    session.run(
        "mypy",
        "-p",
        "sqlalchemy_bigquery",
        "--check-untyped-defs",
        *session.posargs,
    )


@nox.session(python=DEFAULT_PYTHON_VERSION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def create_bigquery_client(
if project_id is None:
project_id = default_project

client_info = google_client_info(user_agent=user_agent)
client_info = user_agent if user_agent is not None else google_client_info()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for assigning client_info appears to be incorrect and introduces a regression. google_client_info (typically an alias for google.api_core.gapic_v1.client_info.ClientInfo) is a class that should be instantiated to return a ClientInfo object. The original code google_client_info(user_agent=user_agent) correctly handled both None and string values for user_agent to produce a valid ClientInfo instance. The new code assigns the raw user_agent string to client_info when it is not None, which will cause an AttributeError when bigquery.Client attempts to access ClientInfo methods (such as to_user_agent()) on what is now a string object.

Suggested change
client_info = user_agent if user_agent is not None else google_client_info()
client_info = google_client_info(user_agent=user_agent)


return bigquery.Client(
client_info=client_info,
Expand Down
8 changes: 6 additions & 2 deletions packages/sqlalchemy-bigquery/sqlalchemy_bigquery/_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,19 @@ def _setup_getitem(self, name):
f"STRUCT fields can only be accessed with strings field names,"
f" not {repr(name)}."
)
subtype = self.expr.type._STRUCT_byname.get(name.lower())
struct_type = self.expr.type
assert isinstance(struct_type, STRUCT)
subtype = struct_type._STRUCT_byname.get(name.lower())
if subtype is None:
raise KeyError(name)
operator = operators.json_getitem_op
index = _field_index(self, name, operator)
return operator, index, subtype

def __getattr__(self, name):
    """Allow STRUCT sub-fields to be accessed as attributes.

    Delegates to ``__getitem__`` when ``name`` matches a field of the
    underlying STRUCT type (case-insensitively); otherwise raises
    ``AttributeError`` to preserve normal attribute-lookup semantics.
    """
    struct_type = self.expr.type
    # Narrow the declared TypeEngine to STRUCT so access to the private
    # _STRUCT_byname mapping is known to be valid (also satisfies mypy).
    assert isinstance(struct_type, STRUCT)
    if name.lower() in struct_type._STRUCT_byname:
        return self[name]
    else:
        raise AttributeError(name)
Expand Down
17 changes: 9 additions & 8 deletions packages/sqlalchemy-bigquery/sqlalchemy_bigquery/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,33 +94,34 @@ def _get_transitive_schema_fields(fields):


def _get_sqla_column_type(field):
    """Map a BigQuery schema field to a SQLAlchemy type instance.

    Unrecognized field types degrade to ``NullType`` with a warning.
    NUMERIC/BIGNUMERIC carry precision and scale, STRING/BYTES carry a
    maximum length, and RECORD/STRUCT recurse into their sub-fields.
    A REPEATED field is wrapped in ``ARRAY``.
    """
    # Distinct names for the looked-up type class vs. the constructed
    # instance keep the two mypy-visible kinds separate.
    col_instance: sqlalchemy.types.TypeEngine
    try:
        col_class = _type_map[field.field_type]
    except KeyError:
        sqlalchemy.util.warn(
            "Did not recognize type '%s' of column '%s'"
            % (field.field_type, field.name)
        )
        col_instance = sqlalchemy.types.NullType()
    else:
        if field.field_type.endswith("NUMERIC"):
            # Covers both NUMERIC and BIGNUMERIC.
            col_instance = col_class(precision=field.precision, scale=field.scale)
        elif field.field_type == "STRING" or field.field_type == "BYTES":
            col_instance = col_class(field.max_length)
        elif field.field_type == "RECORD" or field.field_type == "STRUCT":
            # Recursively build the nested STRUCT's (name, type) pairs.
            col_instance = STRUCT(
                *(
                    (subfield.name, _get_sqla_column_type(subfield))
                    for subfield in field.fields
                )
            )
        else:
            col_instance = col_class()

    if field.mode == "REPEATED":
        col_instance = ARRAY(col_instance)

    return col_instance


def get_columns(bq_schema):
Expand Down
96 changes: 59 additions & 37 deletions packages/sqlalchemy-bigquery/sqlalchemy_bigquery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import random
import re
import uuid
from typing import Any

from google import auth
import google.api_core.exceptions
Expand Down Expand Up @@ -124,8 +125,9 @@ class BigQueryExecutionContext(DefaultExecutionContext):
def create_cursor(self):
    """Create a DB-API cursor, propagating the dialect's arraysize.

    ``arraysize`` is read defensively with ``getattr`` because the
    dialect object may not define the attribute.
    """
    c = super(BigQueryExecutionContext, self).create_cursor()
    arraysize = getattr(self.dialect, "arraysize", None)
    if arraysize:
        c.arraysize = arraysize
    return c

def get_insert_default(self, column): # pragma: NO COVER
Expand Down Expand Up @@ -191,17 +193,21 @@ def pre_exec(self):

class BigQueryCompiler(vendored_postgresql.PGCompiler, SQLCompiler):
compound_keywords = SQLCompiler.compound_keywords.copy()
compound_keywords[selectable.CompoundSelect.UNION] = "UNION DISTINCT"
compound_keywords[selectable.CompoundSelect.UNION_ALL] = "UNION ALL"
compound_keywords[selectable.CompoundSelect.EXCEPT] = "EXCEPT DISTINCT"
compound_keywords[selectable.CompoundSelect.INTERSECT] = "INTERSECT DISTINCT"
# The following attributes are not exposed in SQLAlchemy stubs for CompoundSelect but work at runtime.
# mypy has trouble resolving the type for these, hence the ignore pragma.
compound_keywords[selectable.CompoundSelect.UNION] = "UNION DISTINCT" # type: ignore[attr-defined]
compound_keywords[selectable.CompoundSelect.UNION_ALL] = "UNION ALL" # type: ignore[attr-defined]
compound_keywords[selectable.CompoundSelect.EXCEPT] = "EXCEPT DISTINCT" # type: ignore[attr-defined]
compound_keywords[selectable.CompoundSelect.INTERSECT] = "INTERSECT DISTINCT" # type: ignore[attr-defined]

def __init__(self, dialect, statement, *args, **kwargs):
    """Compile ``statement``; a bare Column compiles without its table prefix."""
    compiling_bare_column = isinstance(statement, Column)
    if compiling_bare_column:
        kwargs["compile_kwargs"] = util.immutabledict({"include_table": False})
    super(BigQueryCompiler, self).__init__(dialect, statement, *args, **kwargs)

def visit_insert(self, insert_stmt, asfrom=False, **kw):
# The signature of visit_insert is incompatible with supertype SQLCompiler in modern SQLAlchemy.
# mypy flags this as an override error, hence the ignore pragma.
def visit_insert(self, insert_stmt, asfrom=False, **kw): # type: ignore[override]
# The (internal) documentation for `inline` is confusing, but
# having `inline` be true prevents us from generating default
# primary-key values when we're doing executemany, which seem broken.
Expand Down Expand Up @@ -272,20 +278,23 @@ def visit_table_valued_alias(self, element, **kw):
def _known_tables(self):
known_tables = set()

for from_ in self.compile_state.froms:
if isinstance(from_, Table):
known_tables.add(from_.name)
elif isinstance(from_, CTE):
known_tables.add(from_.name)
for column in from_.original.selected_columns:
table = getattr(column, "table", None)
if table is not None:
known_tables.add(table.name)
if self.compile_state is not None:
# CompileState.froms is an internal SQLAlchemy attribute not exposed in stubs.
# mypy cannot verify this attribute, hence the ignore pragma.
for from_ in self.compile_state.froms: # type: ignore[attr-defined]
if isinstance(from_, Table):
known_tables.add(from_.name)
elif isinstance(from_, CTE):
known_tables.add(from_.name)
for column in getattr(from_.original, "selected_columns", []):
table = getattr(column, "table", None)
if table is not None:
known_tables.add(table.name)

# If we have the table in the `from` of our parent, do not add the alias
# as this will add the table twice and cause an implicit JOIN for that
# table on itself
asfrom_froms = self.stack[-1].get("asfrom_froms", [])
asfrom_froms: Any = self.stack[-1].get("asfrom_froms", [])
for from_ in asfrom_froms:
if isinstance(from_, Table):
known_tables.add(from_.name)
Expand All @@ -298,6 +307,7 @@ def visit_column(
add_to_result_map=None,
include_table=True,
result_map_targets=(),
ambiguous_table_name_map=None,
**kwargs,
):
name = orig_name = column.name
Expand All @@ -319,7 +329,10 @@ def visit_column(
if is_literal:
name = self.escape_literal_column(name)
else:
name = self.preparer.quote(name, column=True)
# The IdentifierPreparer.quote method in the base SQLAlchemy class
# does not accept a 'column' argument, but the subclass does.
# mypy flags this as an unexpected keyword, hence the ignore pragma.
name = self.preparer.quote(name, column=True) # type: ignore[call-arg]
table = column.table
if table is None or not include_table or not table.named_with_column:
return name
Expand Down Expand Up @@ -455,11 +468,11 @@ def visit_not_endswith_op_binary(self, binary, operator, **kw):

############################################################################

__placeholder = re.compile(r"%\(([^\]:]+)(:[^\]:]+)?\)s$").match
__placeholder = re.compile(r"%\(([^\]:]+)(:[^\]:]+)?\)s$")

__expanded_param = re.compile(
rf"\({__expanding_conflict}\[" rf"{__expanding_text}" rf"_[^\]]+\]\)$"
).match
)

__remove_type_parameter = _helpers.substitute_string_re_method(
r"""
Expand All @@ -479,6 +492,9 @@ def visit_bindparam(
within_columns_clause=False,
literal_binds=False,
skip_bind_expression=False,
literal_execute=False,
render_postcompile=False,
is_upsert_set=False,
**kwargs,
):
type_ = bindparam.type
Expand Down Expand Up @@ -523,9 +539,12 @@ def visit_bindparam(

param = super(BigQueryCompiler, self).visit_bindparam(
bindparam,
within_columns_clause,
literal_binds,
skip_bind_expression,
within_columns_clause=within_columns_clause,
literal_binds=literal_binds,
skip_bind_expression=skip_bind_expression,
literal_execute=literal_execute,
render_postcompile=render_postcompile,
is_upsert_set=is_upsert_set,
**kwargs,
)

Expand All @@ -542,7 +561,7 @@ def visit_bindparam(
if type_.precision is None:
type_.precision = len(t.digits)

if type_.scale is None and t.exponent < 0:
if type_.scale is None and isinstance(t.exponent, int) and t.exponent < 0:
type_.scale = -t.exponent

bq_type = self.dialect.type_compiler.process(type_)
Expand All @@ -551,12 +570,12 @@ def visit_bindparam(
assert_(param != "%s", f"Unexpected param: {param}")

if bindparam.expanding: # pragma: NO COVER
assert_(self.__expanded_param(param), f"Unexpected param: {param}")
assert_(self.__expanded_param.match(param), f"Unexpected param: {param}")
if self.__sqlalchemy_version_info < packaging.version.parse("1.4.27"):
param = param.replace(")", f":{bq_type})")

else:
m = self.__placeholder(param)
m = self.__placeholder.match(param)
if m:
name, type_ = m.groups()
assert_(type_ is None)
Expand Down Expand Up @@ -902,7 +921,10 @@ def _process_time_partitioning(
)

# Generic Case
if partitioning_period not in allowed_partitions:
if (
allowed_partitions is not None
and partitioning_period not in allowed_partitions
):
raise ValueError(
"The TimePartitioning.type_ must be one of: "
f"{allowed_partitions}, received {partitioning_period}."
Expand Down Expand Up @@ -1048,7 +1070,8 @@ def process_array_literal(value):
class BigQueryDialect(DefaultDialect):
name = "bigquery"
driver = "bigquery"
preparer = BigQueryIdentifierPreparer
# Assigning a subclass to a class attribute that expects the base class causes invariance issues in mypy, hence the ignore pragma.
preparer = BigQueryIdentifierPreparer # type: ignore[assignment]
statement_compiler = BigQueryCompiler
type_compiler = BigQueryTypeCompiler
ddl_compiler = BigQueryDDLCompiler
Expand Down Expand Up @@ -1102,7 +1125,7 @@ def __init__(
self.project_id = None
self.billing_project_id = billing_project_id
self.location = location
self.identifier_preparer = self.preparer(self)
self.identifier_preparer = BigQueryIdentifierPreparer(self)
self.dataset_id = None
self.list_tables_page_size = list_tables_page_size

Expand All @@ -1128,7 +1151,7 @@ def create_job_config(self, provided_config: QueryJobConfig):
if self.dataset_id is None and project_id == self.billing_project_id:
return provided_config
job_config = provided_config or QueryJobConfig()
if project_id != self.billing_project_id:
if project_id and project_id != self.billing_project_id:
job_config.connection_properties = [
ConnectionProperty(key="dataset_project_id", value=project_id)
]
Expand Down Expand Up @@ -1391,15 +1414,14 @@ def __init__(self, *args, **kwargs):
raise TypeError("The unnest function requires a single argument.")
arg = args[0]
if isinstance(arg, sqlalchemy.sql.expression.ColumnElement):
if not (
isinstance(arg.type, sqlalchemy.sql.sqltypes.ARRAY)
or (
hasattr(arg.type, "impl")
and isinstance(arg.type.impl, sqlalchemy.sql.sqltypes.ARRAY)
)
if isinstance(arg.type, sqlalchemy.sql.sqltypes.ARRAY):
self.type = arg.type.item_type
elif hasattr(arg.type, "impl") and isinstance(
arg.type.impl, sqlalchemy.sql.sqltypes.ARRAY
):
self.type = arg.type.impl.item_type
else:
raise TypeError("The argument to unnest must have an ARRAY type.")
self.type = arg.type.item_type
super().__init__(*args, **kwargs)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,25 +481,27 @@ def test_custom_expression(
assert len(result) > 0


def test_compiled_query_literal_binds(
    engine, engine_using_test_dataset, table, table_using_test_dataset, query
):
    """Compile a query with literal binds and execute the rendered SQL.

    ``exec_driver_sql`` is preferred when the connection provides it
    (SQLAlchemy 1.4+/2.0) so the pre-rendered SQL string executes
    without re-compilation; older connections fall back to executing
    the compiled object directly.
    """
    q = query(table)
    compiled = q.compile(engine, compile_kwargs={"literal_binds": True})
    with engine.connect() as conn:
        if hasattr(conn, "exec_driver_sql"):
            result = conn.exec_driver_sql(str(compiled)).fetchall()
        else:
            result = conn.execute(compiled).fetchall()
    assert len(result) > 0

    q = query(table_using_test_dataset)
    compiled = q.compile(
        engine_using_test_dataset, compile_kwargs={"literal_binds": True}
    )
    with engine_using_test_dataset.connect() as conn:
        if hasattr(conn, "exec_driver_sql"):
            result = conn.exec_driver_sql(str(compiled)).fetchall()
        else:
            result = conn.execute(compiled).fetchall()
    assert len(result) > 0


Expand Down
Loading
Loading