|
| 1 | +import json |
| 2 | +import sqlite3 |
| 3 | +from typing import Any, Callable, Dict, List, Optional, Type, Union |
| 4 | + |
| 5 | +from sqlalchemy import create_engine as _create_engine |
| 6 | +from sqlalchemy.engine.url import URL |
| 7 | +from sqlalchemy.future import Engine as _FutureEngine |
| 8 | +from sqlalchemy.pool import Pool |
| 9 | +from typing_extensions import Literal, TypedDict |
| 10 | + |
| 11 | +from ..default import Default, _DefaultPlaceholder |
| 12 | + |
| 13 | +# Types defined in sqlalchemy2-stubs, but can't be imported, so re-define here |
| 14 | + |
| 15 | +_Debug = Literal["debug"] |
| 16 | + |
| 17 | +_IsolationLevel = Literal[ |
| 18 | + "SERIALIZABLE", |
| 19 | + "REPEATABLE READ", |
| 20 | + "READ COMMITTED", |
| 21 | + "READ UNCOMMITTED", |
| 22 | + "AUTOCOMMIT", |
| 23 | +] |
| 24 | +_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"] |
| 25 | +_ResetOnReturn = Literal["rollback", "commit"] |
| 26 | + |
| 27 | + |
| 28 | +class _SQLiteConnectArgs(TypedDict, total=False): |
| 29 | + timeout: float |
| 30 | + detect_types: Any |
| 31 | + isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]] |
| 32 | + check_same_thread: bool |
| 33 | + factory: Type[sqlite3.Connection] |
| 34 | + cached_statements: int |
| 35 | + uri: bool |
| 36 | + |
| 37 | + |
| 38 | +_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]] |
| 39 | + |
| 40 | + |
| 41 | +# Re-define create_engine to have by default future=True, and assume that's what is used |
| 42 | +# Also show the default values used for each parameter, but don't set them unless |
| 43 | +# explicitly passed as arguments by the user to prevent errors. E.g. SQLite doesn't |
| 44 | +# support pool connection arguments. |
| 45 | +def create_engine( |
| 46 | + url: Union[str, URL], |
| 47 | + *, |
| 48 | + connect_args: _ConnectArgs = Default({}), # type: ignore |
| 49 | + echo: Union[bool, _Debug] = Default(False), |
| 50 | + echo_pool: Union[bool, _Debug] = Default(False), |
| 51 | + enable_from_linting: bool = Default(True), |
| 52 | + encoding: str = Default("utf-8"), |
| 53 | + execution_options: Dict[Any, Any] = Default({}), |
| 54 | + future: bool = True, |
| 55 | + hide_parameters: bool = Default(False), |
| 56 | + implicit_returning: bool = Default(True), |
| 57 | + isolation_level: Optional[_IsolationLevel] = Default(None), |
| 58 | + json_deserializer: Callable[..., Any] = Default(json.loads), |
| 59 | + json_serializer: Callable[..., Any] = Default(json.dumps), |
| 60 | + label_length: Optional[int] = Default(None), |
| 61 | + logging_name: Optional[str] = Default(None), |
| 62 | + max_identifier_length: Optional[int] = Default(None), |
| 63 | + max_overflow: int = Default(10), |
| 64 | + module: Optional[Any] = Default(None), |
| 65 | + paramstyle: Optional[_ParamStyle] = Default(None), |
| 66 | + pool: Optional[Pool] = Default(None), |
| 67 | + poolclass: Optional[Type[Pool]] = Default(None), |
| 68 | + pool_logging_name: Optional[str] = Default(None), |
| 69 | + pool_pre_ping: bool = Default(False), |
| 70 | + pool_size: int = Default(5), |
| 71 | + pool_recycle: int = Default(-1), |
| 72 | + pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"), |
| 73 | + pool_timeout: float = Default(30), |
| 74 | + pool_use_lifo: bool = Default(False), |
| 75 | + plugins: Optional[List[str]] = Default(None), |
| 76 | + query_cache_size: Optional[int] = Default(None), |
| 77 | + **kwargs: Any, |
| 78 | +) -> _FutureEngine: |
| 79 | + current_kwargs: Dict[str, Any] = { |
| 80 | + "future": future, |
| 81 | + } |
| 82 | + if not isinstance(echo, _DefaultPlaceholder): |
| 83 | + current_kwargs["echo"] = echo |
| 84 | + if not isinstance(echo_pool, _DefaultPlaceholder): |
| 85 | + current_kwargs["echo_pool"] = echo_pool |
| 86 | + if not isinstance(enable_from_linting, _DefaultPlaceholder): |
| 87 | + current_kwargs["enable_from_linting"] = enable_from_linting |
| 88 | + if not isinstance(connect_args, _DefaultPlaceholder): |
| 89 | + current_kwargs["connect_args"] = connect_args |
| 90 | + if not isinstance(encoding, _DefaultPlaceholder): |
| 91 | + current_kwargs["encoding"] = encoding |
| 92 | + if not isinstance(execution_options, _DefaultPlaceholder): |
| 93 | + current_kwargs["execution_options"] = execution_options |
| 94 | + if not isinstance(hide_parameters, _DefaultPlaceholder): |
| 95 | + current_kwargs["hide_parameters"] = hide_parameters |
| 96 | + if not isinstance(implicit_returning, _DefaultPlaceholder): |
| 97 | + current_kwargs["implicit_returning"] = implicit_returning |
| 98 | + if not isinstance(isolation_level, _DefaultPlaceholder): |
| 99 | + current_kwargs["isolation_level"] = isolation_level |
| 100 | + if not isinstance(json_deserializer, _DefaultPlaceholder): |
| 101 | + current_kwargs["json_deserializer"] = json_deserializer |
| 102 | + if not isinstance(json_serializer, _DefaultPlaceholder): |
| 103 | + current_kwargs["json_serializer"] = json_serializer |
| 104 | + if not isinstance(label_length, _DefaultPlaceholder): |
| 105 | + current_kwargs["label_length"] = label_length |
| 106 | + if not isinstance(logging_name, _DefaultPlaceholder): |
| 107 | + current_kwargs["logging_name"] = logging_name |
| 108 | + if not isinstance(max_identifier_length, _DefaultPlaceholder): |
| 109 | + current_kwargs["max_identifier_length"] = max_identifier_length |
| 110 | + if not isinstance(max_overflow, _DefaultPlaceholder): |
| 111 | + current_kwargs["max_overflow"] = max_overflow |
| 112 | + if not isinstance(module, _DefaultPlaceholder): |
| 113 | + current_kwargs["module"] = module |
| 114 | + if not isinstance(paramstyle, _DefaultPlaceholder): |
| 115 | + current_kwargs["paramstyle"] = paramstyle |
| 116 | + if not isinstance(pool, _DefaultPlaceholder): |
| 117 | + current_kwargs["pool"] = pool |
| 118 | + if not isinstance(poolclass, _DefaultPlaceholder): |
| 119 | + current_kwargs["poolclass"] = poolclass |
| 120 | + if not isinstance(pool_logging_name, _DefaultPlaceholder): |
| 121 | + current_kwargs["pool_logging_name"] = pool_logging_name |
| 122 | + if not isinstance(pool_pre_ping, _DefaultPlaceholder): |
| 123 | + current_kwargs["pool_pre_ping"] = pool_pre_ping |
| 124 | + if not isinstance(pool_size, _DefaultPlaceholder): |
| 125 | + current_kwargs["pool_size"] = pool_size |
| 126 | + if not isinstance(pool_recycle, _DefaultPlaceholder): |
| 127 | + current_kwargs["pool_recycle"] = pool_recycle |
| 128 | + if not isinstance(pool_reset_on_return, _DefaultPlaceholder): |
| 129 | + current_kwargs["pool_reset_on_return"] = pool_reset_on_return |
| 130 | + if not isinstance(pool_timeout, _DefaultPlaceholder): |
| 131 | + current_kwargs["pool_timeout"] = pool_timeout |
| 132 | + if not isinstance(pool_use_lifo, _DefaultPlaceholder): |
| 133 | + current_kwargs["pool_use_lifo"] = pool_use_lifo |
| 134 | + if not isinstance(plugins, _DefaultPlaceholder): |
| 135 | + current_kwargs["plugins"] = plugins |
| 136 | + if not isinstance(query_cache_size, _DefaultPlaceholder): |
| 137 | + current_kwargs["query_cache_size"] = query_cache_size |
| 138 | + current_kwargs.update(kwargs) |
| 139 | + return _create_engine(url, **current_kwargs) |
0 commit comments