diff --git a/event_sourcery/event_store/event_store.py b/event_sourcery/event_store/event_store.py index 05ed2bff..a666a38a 100644 --- a/event_sourcery/event_store/event_store.py +++ b/event_sourcery/event_store/event_store.py @@ -25,7 +25,7 @@ def __init__(self, storage_strategy: StorageStrategy, serde: Serde) -> None: self._storage_strategy = storage_strategy self._serde = serde - def load_stream( + async def load_stream( self, stream_id: StreamId, start: int | None = None, @@ -49,11 +49,11 @@ def load_stream( Returns: A sequence of events or empty list if the stream doesn't exist. """ - events = self._storage_strategy.fetch_events(stream_id, start=start, stop=stop) + events = await self._storage_strategy.fetch_events(stream_id, start=start, stop=stop) return self._serde.deserialize_many(events) @singledispatchmethod - def append( + async def append( self, first: WrappedEvent, *events: WrappedEvent, @@ -117,7 +117,7 @@ def _wrap_events_versioning( ) -> Sequence[WrappedEvent]: return [WrappedEvent.wrap(event=event, version=None) for event in events] - def _append( + async def _append( self, stream_id: StreamId, events: Sequence[WrappedEvent], @@ -133,13 +133,13 @@ def _append( else: versioning = NO_VERSIONING - self._storage_strategy.insert_events( + await self._storage_strategy.insert_events( stream_id=stream_id, versioning=versioning, events=self._serde.serialize_many(events, stream_id), ) - def delete_stream(self, stream_id: StreamId) -> None: + async def delete_stream(self, stream_id: StreamId) -> None: """Deletes a stream with a given ID. If a stream does not exist, this method does nothing. @@ -156,9 +156,9 @@ def delete_stream(self, stream_id: StreamId) -> None: Returns: None """ - self._storage_strategy.delete_stream(stream_id) + await self._storage_strategy.delete_stream(stream_id) - def save_snapshot(self, stream_id: StreamId, snapshot: WrappedEvent) -> None: + async def save_snapshot(self, stream_id: StreamId, snapshot: WrappedEvent) -> None: """Saves a snapshot of the stream. Examples: @@ -176,10 +176,10 @@ def save_snapshot(self, stream_id: StreamId, snapshot: WrappedEvent) -> None: """ serialized = self._serde.serialize(event=snapshot, stream_id=stream_id) - self._storage_strategy.save_snapshot(serialized) + await self._storage_strategy.save_snapshot(serialized) @property - def position(self) -> Position | None: + async def position(self) -> Position | None: """Returns the current position of the event store. Examples: @@ -189,7 +189,7 @@ def position(self) -> Position | None: Position(15) # Some events were saved """ - return self._storage_strategy.current_position + return await self._storage_strategy.current_position def scoped_for_tenant(self, tenant_id: TenantId = DEFAULT_TENANT) -> "EventStore": """Factory method to create a new event store instance scoped to a tenant. diff --git a/event_sourcery/event_store/in_memory.py b/event_sourcery/event_store/in_memory.py index 1ae8bbad..46cbea8c 100644 --- a/event_sourcery/event_store/in_memory.py +++ b/event_sourcery/event_store/in_memory.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field, replace from datetime import timedelta from operator import getitem +from typing import AsyncIterator from pydantic import BaseModel, ConfigDict, PositiveInt from typing_extensions import Self @@ -151,9 +152,9 @@ class InMemoryOutboxStorageStrategy(OutboxStorageStrategy): def put_into_outbox(self, records: list[RecordedRaw]) -> None: self._outbox.extend([(e, 0) for e in records if self._filterer(e.entry)]) - def outbox_entries( + async def outbox_entries( self, limit: int - ) -> Iterator[AbstractContextManager[RecordedRaw]]: + ) -> AsyncIterator[AbstractContextManager[RecordedRaw]]: for record in self._outbox[:limit]: yield self._publish_context(*record) @@ -183,7 +184,7 @@ def _reached_max_number_of_attempts(self, failure_count: int) -> bool: class InMemorySubscriptionStrategy(SubscriptionStrategy): _storage: Storage - def subscribe_to_all( + async def subscribe_to_all( self, start_from: Position, batch_size: int, @@ -191,7 +192,7 @@ def subscribe_to_all( ) -> Iterator[list[RecordedRaw]]: return InMemorySubscription(self._storage, start_from, batch_size, timelimit) - def subscribe_to_category( + async def subscribe_to_category( self, start_from: Position, batch_size: int, @@ -206,7 +207,7 @@ def subscribe_to_category( category, ) - def subscribe_to_events( + async def subscribe_to_events( self, start_from: Position, batch_size: int, @@ -236,7 +237,7 @@ def __init__( self._outbox = outbox_strategy self._tenant_id = tenant_id - def fetch_events( + async def fetch_events( self, stream_id: StreamId, start: int | None = None, @@ -250,11 +251,11 @@ def fetch_events( ) return [r.entry for r in stream if r.tenant_id == self._tenant_id] - def insert_events( + async def insert_events( self, stream_id: StreamId, versioning: Versioning, events: list[RawEvent] ) -> None: position = self.current_position or 0 - self._ensure_stream(stream_id=stream_id, versioning=versioning) + await self._ensure_stream(stream_id=stream_id, versioning=versioning) records = [ RecordedRaw(entry=raw, position=position, tenant_id=self._tenant_id) for position, raw in enumerate(events, start=position + 1) @@ -264,7 +265,7 @@ def insert_events( self._outbox.put_into_outbox(records) self._dispatcher.dispatch(*records) - def save_snapshot(self, snapshot: RawEvent) -> None: + async def save_snapshot(self, snapshot: RawEvent) -> None: record = RecordedRaw( entry=snapshot, position=(self.current_position or 0) + 1, @@ -272,7 +273,7 @@ def save_snapshot(self, snapshot: RawEvent) -> None: ) self._storage.replace(with_snapshot=record) - def _ensure_stream(self, stream_id: StreamId, versioning: Versioning) -> None: + async def _ensure_stream(self, stream_id: StreamId, versioning: Versioning) -> None: if stream_id not in self._storage: self._storage.create(stream_id, versioning) @@ -290,12 +291,12 @@ def _ensure_stream(self, stream_id: StreamId, versioning: Versioning) -> None: versioning.expected_version, ) - def delete_stream(self, stream_id: StreamId) -> None: + async def delete_stream(self, stream_id: StreamId) -> None: if stream_id in self._storage: self._storage.delete(stream_id) @property - def current_position(self) -> Position | None: + async def current_position(self) -> Position | None: current_position = self._storage.current_position return current_position and Position(current_position) @@ -387,14 +388,14 @@ class InMemoryKeyStorage(EncryptionKeyStorageStrategy): _keys: dict[tuple[TenantId, str], bytes] = field(default_factory=dict) _tenant_id: TenantId = DEFAULT_TENANT - def get(self, subject_id: str) -> bytes | None: + async def get(self, subject_id: str) -> bytes | None: return self._keys.get((self._tenant_id, subject_id)) - def store(self, subject_id: str, key: bytes) -> None: + async def store(self, subject_id: str, key: bytes) -> None: self._keys[(self._tenant_id, subject_id)] = key - def delete(self, subject_id: str) -> None: + async def delete(self, subject_id: str) -> None: self._keys.pop((self._tenant_id, subject_id), None) - def scoped_for_tenant(self, tenant_id: TenantId) -> Self: + async def scoped_for_tenant(self, tenant_id: TenantId) -> Self: return replace(self, _tenant_id=tenant_id) diff --git a/event_sourcery/event_store/interfaces.py b/event_sourcery/event_store/interfaces.py index 4199f32c..996ccb02 100644 --- a/event_sourcery/event_store/interfaces.py +++ b/event_sourcery/event_store/interfaces.py @@ -2,7 +2,7 @@ from collections.abc import Iterator from contextlib import AbstractContextManager from datetime import timedelta -from typing import Any, Protocol +from typing import Any, Protocol, AsyncIterator from typing_extensions import Self @@ -18,15 +18,15 @@ def __call__(self, entry: RawEvent) -> bool: ... class OutboxStorageStrategy(abc.ABC): @abc.abstractmethod - def outbox_entries( + async def outbox_entries( self, limit: int - ) -> Iterator[AbstractContextManager[RecordedRaw]]: + ) -> AsyncIterator[AbstractContextManager[RecordedRaw]]: pass class SubscriptionStrategy(abc.ABC): @abc.abstractmethod - def subscribe_to_all( + async def subscribe_to_all( self, start_from: Position, batch_size: int, @@ -35,7 +35,7 @@ def subscribe_to_all( pass @abc.abstractmethod - def subscribe_to_category( + async def subscribe_to_category( self, start_from: Position, batch_size: int, @@ -45,7 +45,7 @@ def subscribe_to_category( pass @abc.abstractmethod - def subscribe_to_events( + async def subscribe_to_events( self, start_from: Position, batch_size: int, @@ -57,7 +57,7 @@ def subscribe_to_events( class StorageStrategy(abc.ABC): @abc.abstractmethod - def fetch_events( + async def fetch_events( self, stream_id: StreamId, start: int | None = None, @@ -66,22 +66,22 @@ def fetch_events( pass @abc.abstractmethod - def insert_events( + async def insert_events( self, stream_id: StreamId, versioning: Versioning, events: list[RawEvent] ) -> None: pass @abc.abstractmethod - def save_snapshot(self, snapshot: RawEvent) -> None: + async def save_snapshot(self, snapshot: RawEvent) -> None: pass @abc.abstractmethod - def delete_stream(self, stream_id: StreamId) -> None: + async def delete_stream(self, stream_id: StreamId) -> None: pass @property @abc.abstractmethod - def current_position(self) -> Position | None: + async def current_position(self) -> Position | None: pass @abc.abstractmethod @@ -91,15 +91,15 @@ def scoped_for_tenant(self, tenant_id: str) -> Self: class EncryptionKeyStorageStrategy(abc.ABC): @abc.abstractmethod - def get(self, subject_id: str) -> bytes | None: + async def get(self, subject_id: str) -> bytes | None: pass @abc.abstractmethod - def store(self, subject_id: str, key: bytes) -> None: + async def store(self, subject_id: str, key: bytes) -> None: pass @abc.abstractmethod - def delete(self, subject_id: str) -> None: + async def delete(self, subject_id: str) -> None: pass @abc.abstractmethod @@ -109,9 +109,9 @@ def scoped_for_tenant(self, tenant_id: TenantId) -> Self: class EncryptionStrategy(abc.ABC): @abc.abstractmethod - def encrypt(self, data: Any, key: bytes) -> str: + async def encrypt(self, data: Any, key: bytes) -> str: pass @abc.abstractmethod - def decrypt(self, data: str, key: bytes) -> Any: + async def decrypt(self, data: str, key: bytes) -> Any: pass diff --git a/event_sourcery/event_store/outbox.py b/event_sourcery/event_store/outbox.py index 7973b53c..16d68d7c 100644 --- a/event_sourcery/event_store/outbox.py +++ b/event_sourcery/event_store/outbox.py @@ -9,13 +9,13 @@ def __init__(self, strategy: OutboxStorageStrategy, serde: Serde) -> None: self._strategy = strategy self._serde = serde - def run( + async def run( self, publisher: Callable[[Recorded], None], limit: int = 100, ) -> None: - stream = self._strategy.outbox_entries(limit=limit) - for entry in stream: + stream = await self._strategy.outbox_entries(limit=limit) + async for entry in stream: with entry as raw_record: event = self._serde.deserialize(raw_record.entry) record = Recorded( diff --git a/event_sourcery_kurrentdb/__init__.py b/event_sourcery_kurrentdb/__init__.py index fafa10f0..cf437dd6 100644 --- a/event_sourcery_kurrentdb/__init__.py +++ b/event_sourcery_kurrentdb/__init__.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field from typing import TypeAlias -from kurrentdbclient import KurrentDBClient +from kurrentdbclient import AsyncKurrentDBClient from pydantic import BaseModel, ConfigDict, PositiveFloat, PositiveInt from typing_extensions import Self @@ -44,7 +44,7 @@ class Config(BaseModel): @dataclass(repr=False) class KurrentDBBackendFactory(BackendFactory): - kurrentdb_client: KurrentDBClient + kurrentdb_client: AsyncKurrentDBClient config: Config = field(default_factory=Config) _serde: Serde = field(default_factory=lambda: Serde(EventRegistry())) _outbox_strategy: OutboxStorageStrategy = field( diff --git a/event_sourcery_kurrentdb/event_store.py b/event_sourcery_kurrentdb/event_store.py index e3ab48b7..bc978533 100644 --- a/event_sourcery_kurrentdb/event_store.py +++ b/event_sourcery_kurrentdb/event_store.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, replace from typing import cast -from kurrentdbclient import KurrentDBClient, StreamState +from kurrentdbclient import StreamState, AsyncKurrentDBClient from kurrentdbclient.exceptions import NotFoundError from typing_extensions import Self @@ -21,11 +21,11 @@ @dataclass(repr=False) class KurrentDBStorageStrategy(StorageStrategy): - _client: KurrentDBClient + _client: AsyncKurrentDBClient _timeout: float | None _tenant_id: TenantId = DEFAULT_TENANT - def fetch_events( + async def fetch_events( self, stream_id: StreamId, start: int | None = None, @@ -33,47 +33,47 @@ def fetch_events( ) -> list[RawEvent]: snapshot = None name = stream.Name(self._tenant_id, stream_id) - if start is None and (snapshot := self._read_snapshot(name)) is not None: + if start is None and (snapshot := await self._read_snapshot(name)) is not None: start = cast(int, snapshot.version) + 1 position, limit = stream.scope(start, stop) - entries = self._client.read_stream( + entries = await self._client.read_stream( stream_name=str(name), stream_position=position, limit=limit, timeout=self._timeout, ) try: - events = [dto.raw_event(entry) for entry in entries] + events = [dto.raw_event(entry) async for entry in entries] if snapshot: return [snapshot, *events] return events except NotFoundError: return [] - def _read_snapshot(self, name: stream.Name) -> RawEvent | None: - snapshots = self._client.read_stream( + async def _read_snapshot(self, name: stream.Name) -> RawEvent | None: + snapshots = await self._client.read_stream( name.snapshot, limit=1, backwards=True, timeout=self._timeout, ) try: - last = next(iter(snapshots)) + last = await snapshots.__anext__() return dto.snapshot(last) - except NotFoundError: + except StopAsyncIteration: return None - def insert_events( + async def insert_events( self, stream_id: StreamId, versioning: Versioning, events: list[RawEvent] ) -> None: for sid in {e.stream_id for e in events}: - self._ensure_stream(stream_id=sid, versioning=versioning) + await self._ensure_stream(stream_id=sid, versioning=versioning) stream_name = stream.Name(self._tenant_id, sid) stream_events = [e for e in events if e.stream_id == sid] - self._append_events(stream_name, events=stream_events) + await self._append_events(stream_name, events=stream_events) - def _append_events(self, name: stream.Name, events: list[RawEvent]) -> int: + async def _append_events(self, name: stream.Name, events: list[RawEvent]) -> int: return cast( int, self._client.append_events( @@ -84,17 +84,17 @@ def _append_events(self, name: stream.Name, events: list[RawEvent]) -> int: ), ) - def save_snapshot(self, snapshot: RawEvent) -> None: + async def save_snapshot(self, snapshot: RawEvent) -> None: name = stream.Name(self._tenant_id, snapshot.stream_id) stream_position = stream.Position.from_version(cast(int, snapshot.version)) - self._client.append_events( + await self._client.append_events( name.snapshot, current_version=StreamState.ANY, events=[dto.new_entry(snapshot, stream_position=stream_position)], timeout=self._timeout, ) - def _ensure_stream(self, stream_id: StreamId, versioning: Versioning) -> None: + async def _ensure_stream(self, stream_id: StreamId, versioning: Versioning) -> None: name = stream.Name(self._tenant_id, stream_id) if versioning is not NO_VERSIONING and versioning.expected_version: @@ -102,11 +102,11 @@ def _ensure_stream(self, stream_id: StreamId, versioning: Versioning) -> None: if position := self._get_stream_position(name) != expected: raise ConcurrentStreamWriteError(position, expected) - def _get_stream_position(self, name: stream.Name) -> stream.Position | None: + async def _get_stream_position(self, name: stream.Name) -> stream.Position | None: try: last = next( iter( - self._client.get_stream( + await self._client.get_stream( str(name), backwards=True, limit=1, @@ -118,10 +118,10 @@ def _get_stream_position(self, name: stream.Name) -> stream.Position | None: except NotFoundError: return None - def delete_stream(self, stream_id: StreamId) -> None: + async def delete_stream(self, stream_id: StreamId) -> None: name = stream.Name(self._tenant_id, stream_id) try: - self._client.delete_stream( + await self._client.delete_stream( str(name), current_version=StreamState.ANY, timeout=self._timeout, @@ -130,8 +130,8 @@ def delete_stream(self, stream_id: StreamId) -> None: pass @property - def current_position(self) -> Position | None: - return Position(self._client.get_commit_position(timeout=self._timeout)) + async def current_position(self) -> Position | None: + return Position(await self._client.get_commit_position(timeout=self._timeout)) def scoped_for_tenant(self, tenant_id: TenantId) -> Self: return replace(self, _tenant_id=tenant_id) diff --git a/tests/backend/kurrentdb.py b/tests/backend/kurrentdb.py index 0403e966..b31c9bfb 100644 --- a/tests/backend/kurrentdb.py +++ b/tests/backend/kurrentdb.py @@ -1,28 +1,30 @@ from collections.abc import Iterator from contextlib import contextmanager +from typing import AsyncGenerator, AsyncIterator import pytest -from kurrentdbclient import KurrentDBClient, StreamState +from kurrentdbclient import AsyncKurrentDBClient, StreamState from event_sourcery_kurrentdb import KurrentDBBackendFactory @contextmanager -def kurrentdb_client() -> Iterator[KurrentDBClient]: - client = KurrentDBClient(uri="kurrentdb://localhost:2113?Tls=false") +async def kurrentdb_client() -> AsyncGenerator[AsyncKurrentDBClient, None]: + client = AsyncKurrentDBClient(uri="kurrentdb://localhost:2113?Tls=false") commit_position = client.get_commit_position() yield client for event in client._connection.streams.read(commit_position=commit_position): if not event.stream_name.startswith("$"): - client.delete_stream( + await client.delete_stream( event.stream_name, current_version=StreamState.ANY, ) - for sub in client.list_subscriptions(): - client.delete_subscription(sub.group_name) + subscriptions = await client.list_subscriptions() + for sub in subscriptions: + await client.delete_subscription(sub.group_name) @pytest.fixture() -def kurrentdb(request: pytest.FixtureRequest) -> Iterator[KurrentDBBackendFactory]: - with kurrentdb_client() as client: +async def kurrentdb(request: pytest.FixtureRequest) -> AsyncIterator[KurrentDBBackendFactory]: + async with kurrentdb_client() as client: yield KurrentDBBackendFactory(client) diff --git a/tests/event_store/conftest.py b/tests/event_store/conftest.py index 06af00ac..3e2fa92d 100644 --- a/tests/event_store/conftest.py +++ b/tests/event_store/conftest.py @@ -21,11 +21,11 @@ @pytest.fixture( params=[ - django, + # django, kurrentdb, in_memory, - sqlalchemy_sqlite, - sqlalchemy_postgres, + # sqlalchemy_sqlite, + # sqlalchemy_postgres, ] ) def create_backend_factory(