Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 24 additions & 22 deletions registry/sql-registry/registry/db_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self):

def get_projects(self) -> list[str]:
ret = self.conn.query(
f"select qualified_name from entities where entity_type='{EntityType.Project}'")
f"select qualified_name from entities where entity_type=%s", EntityType.Project)
return list([r["qualified_name"] for r in ret])

def get_entity(self, id_or_name: Union[str, UUID]) -> Entity:
Expand All @@ -40,16 +40,16 @@ def get_entity_id(self, id_or_name: Union[str, UUID]) -> UUID:
pass
# It is a name
ret = self.conn.query(
f"select entity_id from entities where qualified_name='{id_or_name}'")
f"select entity_id from entities where qualified_name=%s", id_or_name)
return ret[0]["entity_id"]

def get_neighbors(self, id_or_name: Union[str, UUID], relationship: RelationshipType) -> list[Edge]:
rows = self.conn.query(fr'''
select edge_id, from_id, to_id, conn_type
from edges
where from_id = '{self.get_entity_id(id_or_name)}'
and conn_type = '{relationship.name}'
''')
where from_id = %s
and conn_type = %s
''', (self.get_entity_id(id_or_name), relationship.name))
return list([Edge(**row) for row in rows])

def get_lineage(self, id_or_name: Union[str, UUID]) -> EntitiesAndRelations:
Expand Down Expand Up @@ -100,9 +100,8 @@ def search_entity(self,
"""
WARN: This search function is implemented via `like` operator, which could be extremely slow.
"""
types = ",".join([quote(str(t)) for t in type])
sql = fr'''select entity_id as id, qualified_name, entity_type as type from entities where qualified_name like %s and entity_type in ({types})'''
rows = self.conn.query(sql, ('%' + keyword + '%', ))
sql = fr'''select entity_id as id, qualified_name, entity_type as type from entities where qualified_name like %s and entity_type in %s'''
rows = self.conn.query(sql, ('%' + keyword + '%', tuple([str(t) for t in type])))
return list([EntityRef(**row) for row in rows])

def create_project(self, definition: ProjectDef) -> UUID:
Expand Down Expand Up @@ -304,7 +303,7 @@ def create_project_derived_feature(self, project_id: UUID, definition: DerivedFe
# Fill `input_anchor_features`, from `definition` we have ids only, we still need qualified names
if definition.input_anchor_features:
c.execute(
fr'''select entity_id, entity_type, qualified_name from entities where entity_id in ({quote(definition.input_anchor_features)}) and entity_type = %s ''', str(EntityType.AnchorFeature))
fr'''select entity_id, entity_type, qualified_name from entities where entity_id in %s and entity_type = %s ''', (tuple([str(id) for id in definition.input_anchor_features]), str(EntityType.AnchorFeature)))
r1 = c.fetchall()
if len(r1) != len(definition.input_anchor_features):
# TODO: More detailed error
Expand All @@ -313,7 +312,7 @@ def create_project_derived_feature(self, project_id: UUID, definition: DerivedFe
r2 = []
if definition.input_derived_features:
c.execute(
fr'''select entity_id, entity_type, qualified_name from entities where entity_id in ({quote(definition.input_derived_features)}) and entity_type = %s ''', str(EntityType.DerivedFeature))
fr'''select entity_id, entity_type, qualified_name from entities where entity_id in %s and entity_type = %s ''', (tuple([str(id) for id in definition.input_anchor_features]), str(EntityType.DerivedFeature)))
r2 = c.fetchall()
if len(r2) != len(definition.input_derived_features):
# TODO: More detailed error
Expand Down Expand Up @@ -387,22 +386,25 @@ def _fill_entity(self, e: Entity) -> Entity:

def _get_edges(self, ids: list[UUID], types: list[RelationshipType] = []) -> list[Edge]:
sql = fr"""select edge_id, from_id, to_id, conn_type from edges
where from_id in ({quote(ids)})
and to_id in ({quote(ids)})"""
where from_id in %(ids)s
and to_id in %(ids)s"""
if len(types) > 0:
sql = fr"""select edge_id, from_id, to_id, conn_type from edges
where conn_type in ({quote(types)})
and from_id in ({quote(ids)})
and to_id in ({quote(ids)})"""
rows = self.conn.query(sql)
where conn_type in %(types)s
and from_id in %(ids)s
and to_id in %(ids)s"""
rows = self.conn.query(sql, {
"ids": tuple([str(id) for id in ids]),
"types": tuple([str(t) for t in types]),
})
return list([_to_type(row, Edge) for row in rows])

def _get_entity(self, id_or_name: Union[str, UUID]) -> Entity:
row = self.conn.query(fr'''
select entity_id, qualified_name, entity_type, attributes
from entities
where entity_id = '{self.get_entity_id(id_or_name)}'
''')[0]
where entity_id = %s
''', self.get_entity_id(id_or_name))[0]
row["attributes"] = json.loads(row["attributes"])
return _to_type(row, Entity)

Expand All @@ -411,8 +413,8 @@ def _get_entities(self, ids: list[UUID]) -> list[Entity]:
return []
rows = self.conn.query(fr'''select entity_id, qualified_name, entity_type, attributes
from entities
where entity_id in ({quote(ids)})
''')
where entity_id in %s
''', (tuple([str(id) for id in ids]), ))
ret = []
for row in rows:
row["attributes"] = json.loads(row["attributes"])
Expand Down Expand Up @@ -448,5 +450,5 @@ def _bfs_step(self, ids: list[UUID], conn_type: RelationshipType) -> set[dict]:
Returns all edges that connect to node ids the next step
"""
ids = list([id["to_id"] for id in ids])
sql = fr"""select edge_id, from_id, to_id, conn_type from edges where conn_type = '{conn_type.name}' and from_id in ({quote(ids)})"""
return self.conn.query(sql)
sql = fr"""select edge_id, from_id, to_id, conn_type from edges where conn_type = %s and from_id in %s"""
return self.conn.query(sql, (conn_type.name, tuple([str(id) for id in ids])))
3 changes: 3 additions & 0 deletions registry/sql-registry/test/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import registry
from registry.models import EntityType
r=registry.DbRegistry()

l=r.get_lineage('226b42ee-0c34-4329-b935-744aecc63fb4').to_dict()
Expand All @@ -15,3 +16,5 @@
assert(len(p.to_dict()['guidEntityMap'])==14)

es=r.search_entity("time", [registry.EntityType.DerivedFeature])
qns=set([e.qualified_name for e in es])
assert qns == set(['feathr_ci_registry_12_33_182947__f_trip_time_distance', 'feathr_ci_registry_12_33_182947__f_trip_time_rounded', 'feathr_ci_registry_12_33_182947__f_trip_time_rounded_plus'])