Skip to content

Commit ec4c15c

Browse files
authored
fix: Upgrade sqlalchemy from 1.x to 2.x regarding PVE-2022-51668. (#4065)
* fix: Upgrade sqlalchemy from 1.x to 2.x regarding PVE-2022-51668. Signed-off-by: Shuchu Han <shuchu.han@gmail.com> * fix: fix typo. Signed-off-by: Shuchu Han <shuchu.han@gmail.com> --------- Signed-off-by: Shuchu Han <shuchu.han@gmail.com>
1 parent 7f1557b commit ec4c15c

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

sdk/python/feast/infra/registry/sql.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def teardown(self):
205205
saved_datasets,
206206
validation_references,
207207
}:
208-
with self.engine.connect() as conn:
208+
with self.engine.begin() as conn:
209209
stmt = delete(t)
210210
conn.execute(stmt)
211211

@@ -399,7 +399,7 @@ def apply_feature_service(
399399
)
400400

401401
def delete_data_source(self, name: str, project: str, commit: bool = True):
402-
with self.engine.connect() as conn:
402+
with self.engine.begin() as conn:
403403
stmt = delete(data_sources).where(
404404
data_sources.c.data_source_name == name,
405405
data_sources.c.project_id == project,
@@ -441,16 +441,19 @@ def _list_on_demand_feature_views(self, project: str) -> List[OnDemandFeatureVie
441441
)
442442

443443
def _list_project_metadata(self, project: str) -> List[ProjectMetadata]:
444-
with self.engine.connect() as conn:
444+
with self.engine.begin() as conn:
445445
stmt = select(feast_metadata).where(
446446
feast_metadata.c.project_id == project,
447447
)
448448
rows = conn.execute(stmt).all()
449449
if rows:
450450
project_metadata = ProjectMetadata(project_name=project)
451451
for row in rows:
452-
if row["metadata_key"] == FeastMetadataKeys.PROJECT_UUID.value:
453-
project_metadata.project_uuid = row["metadata_value"]
452+
if (
453+
row._mapping["metadata_key"]
454+
== FeastMetadataKeys.PROJECT_UUID.value
455+
):
456+
project_metadata.project_uuid = row._mapping["metadata_value"]
454457
break
455458
# TODO(adchia): Add other project metadata in a structured way
456459
return [project_metadata]
@@ -557,7 +560,7 @@ def apply_user_metadata(
557560
table = self._infer_fv_table(feature_view)
558561

559562
name = feature_view.name
560-
with self.engine.connect() as conn:
563+
with self.engine.begin() as conn:
561564
stmt = select(table).where(
562565
getattr(table.c, "feature_view_name") == name,
563566
table.c.project_id == project,
@@ -612,11 +615,11 @@ def get_user_metadata(
612615
table = self._infer_fv_table(feature_view)
613616

614617
name = feature_view.name
615-
with self.engine.connect() as conn:
618+
with self.engine.begin() as conn:
616619
stmt = select(table).where(getattr(table.c, "feature_view_name") == name)
617620
row = conn.execute(stmt).first()
618621
if row:
619-
return row["user_metadata"]
622+
return row._mapping["user_metadata"]
620623
else:
621624
raise FeatureViewNotFoundException(feature_view.name, project=project)
622625

@@ -674,7 +677,7 @@ def _apply_object(
674677
name = name or (obj.name if hasattr(obj, "name") else None)
675678
assert name, f"name needs to be provided for {obj}"
676679

677-
with self.engine.connect() as conn:
680+
with self.engine.begin() as conn:
678681
update_datetime = datetime.utcnow()
679682
update_time = int(update_datetime.timestamp())
680683
stmt = select(table).where(
@@ -723,7 +726,7 @@ def _apply_object(
723726

724727
def _maybe_init_project_metadata(self, project):
725728
# Initialize project metadata if needed
726-
with self.engine.connect() as conn:
729+
with self.engine.begin() as conn:
727730
update_datetime = datetime.utcnow()
728731
update_time = int(update_datetime.timestamp())
729732
stmt = select(feast_metadata).where(
@@ -732,7 +735,7 @@ def _maybe_init_project_metadata(self, project):
732735
)
733736
row = conn.execute(stmt).first()
734737
if row:
735-
usage.set_current_project_uuid(row["metadata_value"])
738+
usage.set_current_project_uuid(row._mapping["metadata_value"])
736739
else:
737740
new_project_uuid = f"{uuid.uuid4()}"
738741
values = {
@@ -753,7 +756,7 @@ def _delete_object(
753756
id_field_name: str,
754757
not_found_exception: Optional[Callable],
755758
):
756-
with self.engine.connect() as conn:
759+
with self.engine.begin() as conn:
757760
stmt = delete(table).where(
758761
getattr(table.c, id_field_name) == name, table.c.project_id == project
759762
)
@@ -777,13 +780,13 @@ def _get_object(
777780
):
778781
self._maybe_init_project_metadata(project)
779782

780-
with self.engine.connect() as conn:
783+
with self.engine.begin() as conn:
781784
stmt = select(table).where(
782785
getattr(table.c, id_field_name) == name, table.c.project_id == project
783786
)
784787
row = conn.execute(stmt).first()
785788
if row:
786-
_proto = proto_class.FromString(row[proto_field_name])
789+
_proto = proto_class.FromString(row._mapping[proto_field_name])
787790
return python_class.from_proto(_proto)
788791
if not_found_exception:
789792
raise not_found_exception(name, project)
@@ -799,20 +802,20 @@ def _list_objects(
799802
proto_field_name: str,
800803
):
801804
self._maybe_init_project_metadata(project)
802-
with self.engine.connect() as conn:
805+
with self.engine.begin() as conn:
803806
stmt = select(table).where(table.c.project_id == project)
804807
rows = conn.execute(stmt).all()
805808
if rows:
806809
return [
807810
python_class.from_proto(
808-
proto_class.FromString(row[proto_field_name])
811+
proto_class.FromString(row._mapping[proto_field_name])
809812
)
810813
for row in rows
811814
]
812815
return []
813816

814817
def _set_last_updated_metadata(self, last_updated: datetime, project: str):
815-
with self.engine.connect() as conn:
818+
with self.engine.begin() as conn:
816819
stmt = select(feast_metadata).where(
817820
feast_metadata.c.metadata_key
818821
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
@@ -846,7 +849,7 @@ def _set_last_updated_metadata(self, last_updated: datetime, project: str):
846849
conn.execute(insert_stmt)
847850

848851
def _get_last_updated_metadata(self, project: str):
849-
with self.engine.connect() as conn:
852+
with self.engine.begin() as conn:
850853
stmt = select(feast_metadata).where(
851854
feast_metadata.c.metadata_key
852855
== FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value,
@@ -855,13 +858,13 @@ def _get_last_updated_metadata(self, project: str):
855858
row = conn.execute(stmt).first()
856859
if not row:
857860
return None
858-
update_time = int(row["last_updated_timestamp"])
861+
update_time = int(row._mapping["last_updated_timestamp"])
859862

860863
return datetime.utcfromtimestamp(update_time)
861864

862865
def _get_all_projects(self) -> Set[str]:
863866
projects = set()
864-
with self.engine.connect() as conn:
867+
with self.engine.begin() as conn:
865868
for table in {
866869
entities,
867870
data_sources,
@@ -872,6 +875,6 @@ def _get_all_projects(self) -> Set[str]:
872875
stmt = select(table)
873876
rows = conn.execute(stmt).all()
874877
for row in rows:
875-
projects.add(row["project_id"])
878+
projects.add(row._mapping["project_id"])
876879

877880
return projects

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
"pygments>=2.12.0,<3",
5858
"PyYAML>=5.4.0,<7",
5959
"requests",
60-
"SQLAlchemy[mypy]>1,<2",
60+
"SQLAlchemy[mypy]>1",
6161
"tabulate>=0.8.0,<1",
6262
"tenacity>=7,<9",
6363
"toml>=0.10.0,<1",

0 commit comments

Comments
 (0)