Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions spanner_orm/admin/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
class SpannerMetadata(object):
"""Gathers information about a table from Spanner."""

@classmethod
def _class_name_from_table(cls, table_name):
return 'table_{}_model'.format(table_name)

@classmethod
def models(cls):
"""Constructs model classes from Spanner table schema."""
Expand All @@ -39,24 +43,20 @@ def models(cls):
for table_name, table_data in tables.items():
primary_index = indexes[table_name][index.Index.PRIMARY_INDEX]
primary_keys = set(primary_index.columns)
klass = model.ModelBase('Model_{}'.format(table_name), (model.Model,),
{})
klass = model.ModelBase(
cls._class_name_from_table(table_name), (model.Model,), {})
for model_field in table_data['fields'].values():
model_field._primary_key = model_field.name in primary_keys # pylint: disable=protected-access

klass.meta = model.Metadata(
table=table_name,
fields=table_data['fields'],
interleaved=table_data['parent_table'],
interleaved=cls._class_name_from_table(table_data['parent_table']),
indexes=indexes[table_name],
model_class=klass)
klass.meta.finalize()
models[table_name] = klass

for table_model in models.values():
if table_model.meta.interleaved:
table_model.meta.interleaved = models[table_model.meta.interleaved]
table_model.meta.finalize()

return models

@classmethod
Expand All @@ -67,9 +67,9 @@ def model(cls, table_name):
def tables(cls):
"""Compiles table information from column schema."""
column_data = collections.defaultdict(dict)
columns = column.ColumnSchema.where(
None, condition.equal_to('table_catalog', ''),
condition.equal_to('table_schema', ''))
columns = column.ColumnSchema.where(None,
condition.equal_to('table_catalog', ''),
condition.equal_to('table_schema', ''))
for column_row in columns:
new_field = field.Field(
column_row.field_type(), nullable=column_row.nullable())
Expand All @@ -78,9 +78,9 @@ def tables(cls):
column_data[column_row.table_name][column_row.column_name] = new_field

table_data = collections.defaultdict(dict)
tables = table.TableSchema.where(
None, condition.equal_to('table_catalog', ''),
condition.equal_to('table_schema', ''))
tables = table.TableSchema.where(None,
condition.equal_to('table_catalog', ''),
condition.equal_to('table_schema', ''))
for table_row in tables:
name = table_row.table_name
table_data[name]['parent_table'] = table_row.parent_table_name
Expand Down
22 changes: 5 additions & 17 deletions spanner_orm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

import collections
import copy
import importlib

from spanner_orm import api
from spanner_orm import condition
from spanner_orm import error
from spanner_orm import field
from spanner_orm import index
from spanner_orm import query
from spanner_orm import registry
from spanner_orm import relationship

from google.cloud import spanner
Expand Down Expand Up @@ -69,6 +69,7 @@ def finalize(self):

for _, relation in self.relations.items():
relation.origin = self.model_class
registry.model_registry().register(self.model_class)
self._finalized = True

def add_metadata(self, metadata):
Expand Down Expand Up @@ -153,9 +154,9 @@ def indexes(cls):

@property
def interleaved(cls):
if cls.meta.interleaved and not isinstance(cls.meta.interleaved, ModelBase):
cls.meta.interleaved = load_model(cls.meta.interleaved)
return cls.meta.interleaved
if cls.meta.interleaved:
return registry.model_registry().get(cls.meta.interleaved)
return None

@property
def primary_keys(cls):
Expand Down Expand Up @@ -426,16 +427,3 @@ def save(self, transaction=None):
self._metaclass.create(transaction, **self.values)
self._persisted = True
return self


def load_model(model_handle):
if isinstance(model_handle, Model):
return model_handle
parts = model_handle.split('.')
path = '.'.join(parts[:-1])
module = importlib.import_module(path)
klass = getattr(module, parts[-1])
if not issubclass(klass, Model):
raise error.SpannerError(
'{model} is not a Model'.format(model=model_handle))
return klass
67 changes: 67 additions & 0 deletions spanner_orm/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Registers Model classes so they can be referenced elsewhere."""

from __future__ import annotations

from typing import Any, Dict, List, Type, Union

import dataclasses
from spanner_orm import error


@dataclasses.dataclass
class RegistryComponent:
references: List[Type[Any]] = dataclasses.field(default_factory=list)

def add(self, reference: Type[Any]) -> None:
self.references.append(reference)


class Registry(object):

def __init__(self):
self._registered = {} # type: Dict[str, RegistryComponent]

def _name_from_class(self, klass: Type[Any]) -> str:
return '{}.{}'.format(klass.__module__, klass.__name__)

def register(self, to_register: Type[Any]) -> None:
name_components = reversed(self._name_from_class(to_register).split('.'))
name = None
for component in name_components:
name = name = '{}.{}'.format(component, name) if name else component
if name not in self._registered:
self._registered[name] = RegistryComponent()
self._registered[name].add(to_register)

def get(self, name: Union[Type[Any], str]) -> Type[Any]:
if isinstance(name, type):
name = self._name_from_class(name)

if name not in self._registered:
raise error.SpannerError(
'{} was not found, verify it has been imported'.format(name))
if len(self._registered[name].references) > 1:
raise error.SpannerError(
'Multiple classes match {}, add more specificity'.format(name))
return self._registered[name].references[0]


_registry = Registry()


def model_registry():
return _registry
4 changes: 3 additions & 1 deletion spanner_orm/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import dataclasses
from spanner_orm import error
from spanner_orm import model
from spanner_orm import registry


@dataclasses.dataclass
Expand Down Expand Up @@ -67,7 +68,8 @@ def constraints(self) -> List[RelationshipConstraint]:
@property
def destination(self) -> Type[model.Model]:
if not self._destination:
self._destination = model.load_model(self._destination_handle)
self._destination = registry.model_registry().get(
self._destination_handle)
return self._destination

@property
Expand Down
3 changes: 3 additions & 0 deletions spanner_orm/tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def test_relation_get_error_on_unretrieved(self):
with self.assertRaises(AttributeError):
_ = test_model.parent

def test_interleaved(self):
self.assertEqual(models.ChildTestModel.interleaved, models.SmallTestModel)

@mock.patch('spanner_orm.model.ModelMeta.find')
def test_reload(self, find):
values = {'key': 'key', 'value_1': 'value_1'}
Expand Down
2 changes: 1 addition & 1 deletion spanner_orm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ChildTestModel(model.Model):
"""Model class for testing interleaved tables."""

__table__ = 'ChildTestModel'
__interleaved__ = SmallTestModel
__interleaved__ = 'SmallTestModel'

key = field.Field(field.String, primary_key=True)
child_key = field.Field(field.String, primary_key=True)
Expand Down