|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import copy |
4 | 3 | import datetime |
5 | 4 | import decimal |
6 | 5 | from abc import ABC, abstractmethod |
@@ -495,6 +494,7 @@ class DbSqlType(Enum): |
495 | 494 | BOOLEAN = "BOOLEAN" |
496 | 495 | INTERVAL_MONTH = "INTERVAL MONTH" |
497 | 496 | INTERVAL_DAY = "INTERVAL DAY" |
| 497 | + VOID = "VOID" |
498 | 498 |
|
499 | 499 |
|
500 | 500 | class DbSqlParameter: |
@@ -542,20 +542,41 @@ def infer_types(params: list[DbSqlParameter]): |
542 | 542 | datetime.date: DbSqlType.DATE, |
543 | 543 | bool: DbSqlType.BOOLEAN, |
544 | 544 | Decimal: DbSqlType.DECIMAL, |
| 545 | + type(None): DbSqlType.VOID, |
545 | 546 | } |
546 | | - new_params = copy.deepcopy(params) |
547 | | - for param in new_params: |
548 | | - if not param.type: |
549 | | - if type(param.value) in type_lookup_table: |
550 | | - param.type = type_lookup_table[type(param.value)] |
551 | | - else: |
552 | | - raise ValueError("Parameter type cannot be inferred") |
553 | | - |
554 | | - if param.type == DbSqlType.DECIMAL: |
| 547 | + |
| 548 | + new_params = [] |
| 549 | + |
| 550 | + # cycle through each parameter we've been passed |
| 551 | + for param in params: |
| 552 | + _name: str = param.name |
| 553 | + _value: Any = param.value |
| 554 | + _type: Union[DbSqlType, DbsqlDynamicDecimalType, Enum, None] |
| 555 | + |
| 556 | + if param.type: |
| 557 | + _type = param.type |
| 558 | + else: |
| 559 | + # figure out what type to use |
| 560 | + _type = type_lookup_table.get(type(_value), None) |
| 561 | + if not _type: |
| 562 | + raise ValueError( |
| 563 | + f"Could not infer parameter type from {type(param.value)} - {param.value}" |
| 564 | + ) |
| 565 | + |
| 566 | + # Decimal require special handling because one column type in Databricks can have multiple precisions |
| 567 | + if _type == DbSqlType.DECIMAL: |
555 | 568 | cast_exp = calculate_decimal_cast_string(param.value) |
556 | | - param.type = DbsqlDynamicDecimalType(cast_exp) |
| 569 | + _type = DbsqlDynamicDecimalType(cast_exp) |
| 570 | + |
| 571 | + # VOID / NULL types must be passed in a unique way as TSparkParameters with no value |
| 572 | + if _type == DbSqlType.VOID: |
| 573 | + new_params.append(DbSqlParameter(name=_name, type=DbSqlType.VOID)) |
| 574 | + continue |
| 575 | + else: |
| 576 | + _value = str(param.value) |
| 577 | + |
| 578 | + new_params.append(DbSqlParameter(name=_name, value=_value, type=_type)) |
557 | 579 |
|
558 | | - param.value = str(param.value) |
559 | 580 | return new_params |
560 | 581 |
|
561 | 582 |
|
@@ -594,11 +615,15 @@ def named_parameters_to_tsparkparams(parameters: Union[List[Any], Dict[str, str] |
594 | 615 | dbsql_params = named_parameters_to_dbsqlparams_v2(parameters) |
595 | 616 | inferred_type_parameters = infer_types(dbsql_params) |
596 | 617 | for param in inferred_type_parameters: |
597 | | - tspark_params.append( |
598 | | - TSparkParameter( |
599 | | - type=param.type.value, |
600 | | - name=param.name, |
601 | | - value=TSparkParameterValue(stringValue=param.value), |
| 618 | + # The only way to pass a VOID/NULL to DBR is to declare TSparkParameter without declaring |
| 619 | + # its value or type arguments. If we set these to NoneType, the request will fail with a |
| 620 | + # thrift transport error |
| 621 | + if param.type == DbSqlType.VOID: |
| 622 | + this_tspark_param = TSparkParameter(name=param.name) |
| 623 | + else: |
| 624 | + this_tspark_param_value = TSparkParameterValue(stringValue=param.value) |
| 625 | + this_tspark_param = TSparkParameter( |
| 626 | + type=param.type.value, name=param.name, value=this_tspark_param_value |
602 | 627 | ) |
603 | | - ) |
| 628 | + tspark_params.append(this_tspark_param) |
604 | 629 | return tspark_params |
0 commit comments