From 17a7ace869c6ebbae25222baf558e743b1f9762f Mon Sep 17 00:00:00 2001 From: Derek Brandao Date: Wed, 13 Mar 2019 14:08:03 +0000 Subject: [PATCH 1/3] Make model specification and loading more robust Right now, in relationships and parent table specifications, specifying the model name requires a fully-specified class name. Adding a registry so we don't have to load classes on the fly and instead register classes as they load, so we can specify a less-qualified class name and avoid having to read files to resolve a model class --- spanner_orm/admin/metadata.py | 28 +++++++------- spanner_orm/model.py | 22 +++-------- spanner_orm/registry.py | 67 +++++++++++++++++++++++++++++++++ spanner_orm/relationship.py | 5 ++- spanner_orm/tests/model_test.py | 3 ++ spanner_orm/tests/models.py | 2 +- 6 files changed, 93 insertions(+), 34 deletions(-) create mode 100644 spanner_orm/registry.py diff --git a/spanner_orm/admin/metadata.py b/spanner_orm/admin/metadata.py index 482551f..718ef06 100644 --- a/spanner_orm/admin/metadata.py +++ b/spanner_orm/admin/metadata.py @@ -29,6 +29,10 @@ class SpannerMetadata(object): """Gathers information about a table from Spanner.""" + @classmethod + def _class_name_from_table(cls, table_name): + return 'table_{}_model'.format(table_name) + @classmethod def models(cls): """Constructs model classes from Spanner table schema.""" @@ -39,24 +43,20 @@ def models(cls): for table_name, table_data in tables.items(): primary_index = indexes[table_name][index.Index.PRIMARY_INDEX] primary_keys = set(primary_index.columns) - klass = model.ModelBase('Model_{}'.format(table_name), (model.Model,), - {}) + klass = model.ModelBase( + cls._class_name_from_table(table_name), (model.Model,), {}) for model_field in table_data['fields'].values(): model_field._primary_key = model_field.name in primary_keys # pylint: disable=protected-access klass.meta = model.Metadata( table=table_name, fields=table_data['fields'], - interleaved=table_data['parent_table'], + interleaved=cls._class_name_from_table(table_data['parent_table']), indexes=indexes[table_name], model_class=klass) + klass.meta.finalize() models[table_name] = klass - for table_model in models.values(): - if table_model.meta.interleaved: - table_model.meta.interleaved = models[table_model.meta.interleaved] - table_model.meta.finalize() - return models @classmethod @@ -67,9 +67,9 @@ def model(cls, table_name): def tables(cls): """Compiles table information from column schema.""" column_data = collections.defaultdict(dict) - columns = column.ColumnSchema.where( - None, condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) + columns = column.ColumnSchema.where(None, + condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', '')) for column_row in columns: new_field = field.Field( column_row.field_type(), nullable=column_row.nullable()) @@ -78,9 +78,9 @@ def tables(cls): column_data[column_row.table_name][column_row.column_name] = new_field table_data = collections.defaultdict(dict) - tables = table.TableSchema.where( - None, condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) + tables = table.TableSchema.where(None, + condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', '')) for table_row in tables: name = table_row.table_name table_data[name]['parent_table'] = table_row.parent_table_name diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 0f934e7..c818421 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -16,7 +16,6 @@ import collections import copy -import importlib from spanner_orm import api from spanner_orm import condition @@ -24,6 +23,7 @@ from spanner_orm import field from spanner_orm import index from spanner_orm import query +from spanner_orm import registry from spanner_orm import relationship from google.cloud import spanner @@ -69,6 +69,7 @@ def finalize(self): for _, relation in self.relations.items(): relation.origin = self.model_class + registry.model_registry().register(self.model_class) self._finalized = True def add_metadata(self, metadata): @@ -153,9 +154,9 @@ def indexes(cls): @property def interleaved(cls): - if cls.meta.interleaved and not isinstance(cls.meta.interleaved, ModelBase): - cls.meta.interleaved = load_model(cls.meta.interleaved) - return cls.meta.interleaved + if cls.meta.interleaved: + return registry.model_registry().get(cls.meta.interleaved) + return None @property def primary_keys(cls): @@ -426,16 +427,3 @@ def save(self, transaction=None): self._metaclass.create(transaction, **self.values) self._persisted = True return self - - -def load_model(model_handle): - if isinstance(model_handle, Model): - return model_handle - parts = model_handle.split('.') - path = '.'.join(parts[:-1]) - module = importlib.import_module(path) - klass = getattr(module, parts[-1]) - if not issubclass(klass, Model): - raise error.SpannerError( - '{model} is not a Model'.format(model=model_handle)) - return klass diff --git a/spanner_orm/registry.py b/spanner_orm/registry.py new file mode 100644 index 0000000..1c596cc --- /dev/null +++ b/spanner_orm/registry.py @@ -0,0 +1,67 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Registers Model classes so they can be referenced elsewhere.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Type, Union + +import dataclasses +from spanner_orm import error + + +@dataclasses.dataclass +class RegistryComponent: + references: List[Type[Any]] = dataclasses.field(default_factory=list) + + def add(self, reference: Type[Any]) -> None: + self.references.append(reference) + + +class Registry(object): + + def __init__(self): + self._registered = {} # type: Dict[str, RegistryComponent] + + def _name_from_class(self, klass: Type[Any]) -> str: + return '{}.{}'.format(klass.__module__, klass.__name__) + + def register(self, to_register: Type[Any]) -> None: + name_components = reversed(self._name_from_class(to_register).split('.')) + name = None + for component in name_components: + name = name = '{}.{}'.format(component, name) if name else component + if name not in self._registered: + self._registered[name] = RegistryComponent() + self._registered[name].add(to_register) + + def get(self, name: Union[Type[Any], str]) -> Type[Any]: + if isinstance(name, type): + name = self._name_from_class(name) + + if name not in self._registered: + raise error.SpannerError( + '{} was not found, verify it has been imported'.format(name)) + if len(self._registered[name].references) > 1: + raise error.SpannerError( + 'Multiple classes match {}, add more specificity'.format(name)) + return self._registered[name].references[0] + + +_registry = Registry() + + +def model_registry(): + return _registry diff --git a/spanner_orm/relationship.py b/spanner_orm/relationship.py index f44d413..4d396b3 100644 --- a/spanner_orm/relationship.py +++ b/spanner_orm/relationship.py @@ -21,6 +21,7 @@ from spanner_orm import condition from spanner_orm import error from spanner_orm import model +from spanner_orm import registry class Relationship(object): @@ -59,7 +60,8 @@ def conditions(self) -> List[condition.Condition]: @property def destination(self) -> Type[model.Model]: if not self._destination: - self._destination = model.load_model(self._destination_handle) + self._destination = registry.model_registry().get( + self._destination_handle) return self._destination @property @@ -82,5 +84,4 @@ def _parse_constraints(self) -> List[condition.Condition]: conditions.append( condition.ColumnsEqualCondition(destination_column, self.origin, origin_column)) - return conditions diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index bd190db..0027676 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -124,6 +124,9 @@ def test_relation_get_error_on_unretrieved(self): with self.assertRaises(AttributeError): _ = test_model.parent + def test_interleaved(self): + self.assertEqual(models.ChildTestModel.interleaved, models.SmallTestModel) + @mock.patch('spanner_orm.model.ModelMeta.find') def test_reload(self, find): values = {'key': 'key', 'value_1': 'value_1'} diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 4958e00..9951614 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -33,7 +33,7 @@ class ChildTestModel(model.Model): """Model class for testing interleaved tables.""" __table__ = 'ChildTestModel' - __interleaved__ = SmallTestModel + __interleaved__ = 'SmallTestModel' key = field.Field(field.String, primary_key=True) child_key = field.Field(field.String, primary_key=True) From 749851adc710ea3645bcef0f96bb795eb1deaa34 Mon Sep 17 00:00:00 2001 From: Derek Brandao <38337796+dcbrandao@users.noreply.github.com> Date: Wed, 13 Mar 2019 14:09:45 -0400 Subject: [PATCH 2/3] Update relationship.py --- spanner_orm/relationship.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spanner_orm/relationship.py b/spanner_orm/relationship.py index 5a37909..674257d 100644 --- a/spanner_orm/relationship.py +++ b/spanner_orm/relationship.py @@ -93,4 +93,5 @@ def _parse_constraints(self) -> List[RelationshipConstraint]: RelationshipConstraint(self.destination, destination_column, self.origin, origin_column)) # type: ignore - return constraints \ No newline at end of file + return constraints + From 3051112ffd8c1e2cf19a9f4f73651a05e16adc3e Mon Sep 17 00:00:00 2001 From: Derek Brandao Date: Wed, 13 Mar 2019 18:10:42 +0000 Subject: [PATCH 3/3] fix newline --- spanner_orm/relationship.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spanner_orm/relationship.py b/spanner_orm/relationship.py index 674257d..8986502 100644 --- a/spanner_orm/relationship.py +++ b/spanner_orm/relationship.py @@ -94,4 +94,3 @@ def _parse_constraints(self) -> List[RelationshipConstraint]: self.origin, origin_column)) # type: ignore return constraints -