Skip to content

Commit 6b2e8c6

Browse files
committed
get_by_batchref in abstract+real repo, and a new integration test for it [get_by_batchref]
1 parent eb11e04 commit 6b2e8c6

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

src/allocation/repository.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Set
22
import abc
3-
from allocation import model
3+
from allocation import model, orm
4+
45

56

67
class AbstractRepository(abc.ABC):
@@ -18,6 +19,12 @@ def get(self, sku) -> model.Product:
1819
self.seen.add(product)
1920
return product
2021

22+
def get_by_batchref(self, batchref) -> model.Product:
23+
product = self._get_by_batchref(batchref)
24+
if product:
25+
self.seen.add(product)
26+
return product
27+
2128
@abc.abstractmethod
2229
def _add(self, product: model.Product):
2330
raise NotImplementedError
@@ -26,6 +33,11 @@ def _add(self, product: model.Product):
2633
def _get(self, sku) -> model.Product:
2734
raise NotImplementedError
2835

36+
@abc.abstractmethod
37+
def _get_by_batchref(self, batchref) -> model.Product:
38+
raise NotImplementedError
39+
40+
2941

3042

3143
class SqlAlchemyRepository(AbstractRepository):
@@ -39,3 +51,8 @@ def _add(self, product):
3951

4052
def _get(self, sku):
4153
return self.session.query(model.Product).filter_by(sku=sku).first()
54+
55+
def _get_by_batchref(self, batchref):
56+
return self.session.query(model.Product).join(model.Batch).filter(
57+
orm.batches.c.reference == batchref,
58+
).first()
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from allocation import model, repository
2+
3+
def test_get_by_batchref(session):
4+
repo = repository.SqlAlchemyRepository(session)
5+
b1 = model.Batch(ref='b1', sku='sku1', qty=100, eta=None)
6+
b2 = model.Batch(ref='b2', sku='sku1', qty=100, eta=None)
7+
b3 = model.Batch(ref='b3', sku='sku2', qty=100, eta=None)
8+
p1 = model.Product(sku='sku1', batches=[b1, b2])
9+
p2 = model.Product(sku='sku2', batches=[b3])
10+
repo.add(p1)
11+
repo.add(p2)
12+
assert repo.get_by_batchref('b2') == p1
13+
assert repo.get_by_batchref('b3') == p2
14+

0 commit comments

Comments
 (0)