Skip to content

Commit 60b36ea

Browse files
authored
Merge pull request #46 from sede-open/release/release-2024-02-05-pk
Release/release 2024 02 05 pk
2 parents 0395eb3 + 861584b commit 60b36ea

2 files changed

Lines changed: 78 additions & 17 deletions

File tree

src/databricks/sqlalchemy/dialect/__init__.py

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
from dateutil.parser import parse
66

77
from sqlalchemy import types, processors, event
8-
from sqlalchemy.engine import default, Engine
8+
from sqlalchemy.engine import default, Engine, reflection
99
from sqlalchemy.exc import DatabaseError
10-
from sqlalchemy.engine import reflection
1110

1211
from databricks import sql
1312

@@ -182,24 +181,45 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
182181

183182
return columns
184183

184+
@reflection.cache
185185
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
186-
"""Return information about the primary key constraint on
187-
table_name`.
188-
189-
Given a :class:`_engine.Connection`, a string
190-
`table_name`, and an optional string `schema`, return primary
191-
key information as a dictionary with these keys:
186+
"""Fetch information about the primary key constraint on table_name.
192187
193-
constrained_columns
194-
a list of column names that make up the primary key
188+
Returns a dictionary with these keys:
189+
constrained_columns
190+
a list of column names that make up the primary key. Results is an empty list
191+
if no PRIMARY KEY is defined.
195192
196-
name
197-
optional name of the primary key constraint.
193+
name
194+
the name of the primary key constraint
198195
199196
"""
200-
# TODO: implement this behaviour
201-
return {"constrained_columns": []}
197+
# TODO: abstract this to databricks.sql.client
198+
CONSTRAINT_NAME = 1
199+
COLUMN_NAME = 2
200+
201+
with self.get_driver_connection(
202+
connection
203+
)._dbapi_connection.dbapi_connection.cursor() as cur:
204+
pk_query = """
205+
SELECT table_name, constraint_name, column_name
206+
FROM information_schema.constraint_column_usage
207+
WHERE table_schema = '{schema}'
208+
AND table_name = '{table}'
209+
AND constraint_name LIKE 'pk_%'
210+
""".format(
211+
schema=schema,
212+
table=table_name
213+
)
214+
215+
data = cur.execute(pk_query).fetchall()
216+
217+
cols = [i[COLUMN_NAME] for i in data]
218+
name = [i[CONSTRAINT_NAME] for i in data]
219+
220+
return {"constrained_columns": cols, "name": name}
202221

222+
@reflection.cache
203223
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
204224
"""Return information about foreign_keys in `table_name`.
205225
@@ -223,8 +243,46 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
223243
a list of column names in the referred table that correspond to
224244
constrained_columns
225245
"""
226-
# TODO: Implement this behaviour
227-
return []
246+
# TODO: abstract this to databricks.sql.client
247+
# TODO: can only process 1:1 FK relationships
248+
CONSTRAINT_NAME = 2
249+
COLUMN_NAME = 3
250+
CONTRAINT_SCHEMA = 0
251+
TABLE_NAME = 1
252+
253+
with self.get_driver_connection(
254+
connection
255+
)._dbapi_connection.dbapi_connection.cursor() as cur:
256+
fk_query = """
257+
SELECT constraint_schema, table_name, constraint_name, column_name
258+
FROM information_schema.constraint_column_usage
259+
WHERE table_schema = '{schema}'
260+
AND table_name = '{table}'
261+
AND constraint_name LIKE 'fk_%'
262+
""".format(
263+
schema=schema,
264+
table=table_name
265+
)
266+
267+
data = cur.execute(fk_query).fetchall()
268+
269+
fkeys = []
270+
for fk in data:
271+
name = fk[CONSTRAINT_NAME]
272+
col_name = fk[COLUMN_NAME]
273+
con_schema = fk[CONTRAINT_SCHEMA]
274+
table = fk[TABLE_NAME]
275+
276+
fkey_d = {
277+
"name": name,
278+
"constrained_columns": col_name,
279+
"referred_schema": con_schema,
280+
"referred_table": "charger_evse", # TODO: Replace, hardcode for testing
281+
"referred_columns": col_name,
282+
}
283+
fkeys.append(fkey_d)
284+
285+
return fkeys
228286

229287
def get_indexes(self, connection, table_name, schema=None, **kw):
230288
"""Return information about indexes in `table_name`.

src/databricks/sqlalchemy/dialect/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,11 @@ def visit_create_table(self, create, **kw):
101101
processed = self.process(
102102
create_column, first_pk=column.primary_key and not first_pk
103103
)
104+
# Add backquotes to column names if there are none - assumes no spaces in column name
105+
if '`' not in processed:
106+
processed = '`' + "".join(processed.split(" ")[0]) + "` " + " ".join(processed.split(" ")[1:])
104107
if column.autoincrement is True: # If doesn't work try 'is True' and == 'True'
105-
processed = "`".join(processed.split("`")[:-1]) + "` " + "BIGINT GENERATED ALWAYS AS IDENTITY"
108+
processed = "`".join(processed.split("`")[:-1]) + "` " + "BIGINT GENERATED BY DEFAULT AS IDENTITY"
106109
if processed is not None:
107110
text += separator
108111
separator = ", \n"

0 commit comments

Comments
 (0)