Skip to content

Commit 803d9b6

Browse files
committed
implement .seen on repository [repository_tracks_seen]
1 parent bfdd979 commit 803d9b6

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

src/allocation/repository.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,41 @@
1-
import abc
21
from typing import Set
2+
import abc
33
from allocation import model
44

55

66
class AbstractRepository(abc.ABC):
77

8-
@abc.abstractmethod
8+
def __init__(self):
9+
self.seen = set() # type: Set[model.Product]
10+
911
def add(self, product: model.Product):
12+
self._add(product)
13+
self.seen.add(product)
14+
15+
def get(self, sku) -> model.Product:
16+
product = self._get(sku)
17+
if product:
18+
self.seen.add(product)
19+
return product
20+
21+
@abc.abstractmethod
22+
def _add(self, product: model.Product):
1023
raise NotImplementedError
1124

1225
@abc.abstractmethod
13-
def get(self, sku) -> model.Product:
26+
def _get(self, sku) -> model.Product:
1427
raise NotImplementedError
1528

1629

1730

1831
class SqlAlchemyRepository(AbstractRepository):
1932

2033
def __init__(self, session):
34+
super().__init__()
2135
self.session = session
22-
self.seen = set() # type: Set[model.Product]
2336

24-
def add(self, product):
25-
self.seen.add(product)
37+
def _add(self, product):
2638
self.session.add(product)
2739

28-
def get(self, sku):
29-
product = self.session.query(model.Product).filter_by(sku=sku).first()
30-
if product:
31-
self.seen.add(product)
32-
return product
40+
def _get(self, sku):
41+
return self.session.query(model.Product).filter_by(sku=sku).first()

tests/unit/test_services.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
class FakeRepository(repository.AbstractRepository):
77

88
def __init__(self, products):
9+
super().__init__()
910
self._products = set(products)
1011

11-
def add(self, product):
12+
def _add(self, product):
1213
self._products.add(product)
1314

14-
def get(self, sku):
15+
def _get(self, sku):
1516
return next((p for p in self._products if p.sku == sku), None)
1617

1718

0 commit comments

Comments
 (0)