Skip to content

Commit bca1a7b

Browse files
committed
test: add missing async test update
refactor: mypy issues with annotations refactor: simplify async session using SQLmodel to pass mypy use typdef Annotation see https://pydantic-docs.helpmanual.io/usage/schema/#typingannotated-fields
1 parent 886b962 commit bca1a7b

7 files changed

Lines changed: 97 additions & 36 deletions

File tree

docs_src/tutorial_async/connect_async/update_async/tutorial001.py renamed to docs_src/tutorial_async/connect_async/update_async/tutorial001_async.py

File renamed without changes.

docs_src/tutorial_async/select_async/tutorial001_async.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import asyncio
2-
from typing import Optional
2+
from typing import Annotated, Optional
33

44
from sqlalchemy.ext.asyncio import create_async_engine
5-
from sqlalchemy.orm import sessionmaker
6-
from sqlmodel import Field, Session, SQLModel, select
5+
6+
from sqlmodel import Field, SQLModel, select
77
from sqlmodel.ext.asyncio.session import AsyncSession
88

99

1010
class Hero(SQLModel, table=True):
11-
id: Optional[int] = Field(default=None, primary_key=True)
11+
id: Annotated[int, Field(primary_key=True)]
1212
name: str
1313
secret_name: str
1414
age: Optional[int] = None
@@ -34,8 +34,7 @@ async def create_heroes():
3434
hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador")
3535
hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
3636

37-
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
38-
async with async_session() as session:
37+
async with AsyncSession(engine) as session:
3938
session.add(hero_1)
4039
session.add(hero_2)
4140
session.add(hero_3)
@@ -44,8 +43,7 @@ async def create_heroes():
4443

4544

4645
async def select_heroes():
47-
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
48-
async with async_session() as session:
46+
async with AsyncSession(engine) as session:
4947
statement = select(Hero)
5048
results = await session.exec(statement)
5149
for hero in results:
@@ -59,4 +57,4 @@ async def main():
5957

6058

6159
if __name__ == "__main__":
62-
asyncio.run(main)
60+
asyncio.run(main())

docs_src/tutorial_async/select_async/tutorial002_async.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import asyncio
2-
from typing import Optional
2+
from typing import Annotated, Optional
33

44
from sqlalchemy.ext.asyncio import create_async_engine
5-
from sqlalchemy.orm import sessionmaker
6-
from sqlmodel import Field, Session, SQLModel, select
5+
6+
from sqlmodel import Field, SQLModel, select
77
from sqlmodel.ext.asyncio.session import AsyncSession # (1)
88

99

1010
class Hero(SQLModel, table=True): # (2)
11-
id: Optional[int] = Field(default=None, primary_key=True)
11+
id: Annotated[int, Field(primary_key=True)]
1212
name: str
1313
secret_name: str
1414
age: Optional[int] = None
@@ -33,8 +33,7 @@ async def create_heroes():
3333
hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador")
3434
hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
3535

36-
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
37-
async with async_session() as session: # (6)
36+
async with AsyncSession(engine) as session: # (6)
3837
session.add(hero_1)
3938
session.add(hero_2)
4039
session.add(hero_3)
@@ -43,8 +42,7 @@ async def create_heroes():
4342

4443

4544
async def select_heroes():
46-
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
47-
async with async_session() as session: # (7)
45+
async with AsyncSession(engine) as session: # (7)
4846
statement = select(Hero) # (8)
4947
results = await session.exec(statement) # (9)
5048
for hero in results: # (10)
@@ -59,4 +57,4 @@ async def main():
5957

6058

6159
if __name__ == "__main__":
62-
asyncio.run(main)
60+
asyncio.run(main())

docs_src/tutorial_async/select_async/tutorial003_async.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import asyncio
2-
from typing import Optional
2+
from typing import Annotated, Optional
33

44
from sqlalchemy.ext.asyncio import create_async_engine
5-
from sqlalchemy.orm import sessionmaker
6-
from sqlmodel import Field, Session, SQLModel, select
5+
6+
from sqlmodel import Field, SQLModel, select
77
from sqlmodel.ext.asyncio.session import AsyncSession
88

99

1010
class Hero(SQLModel, table=True):
11-
id: Optional[int] = Field(default=None, primary_key=True)
11+
id: Annotated[int, Field(primary_key=True)]
1212
name: str
1313
secret_name: str
1414
age: Optional[int] = None
@@ -33,8 +33,7 @@ async def create_heroes():
3333
hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador")
3434
hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
3535

36-
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
37-
async with async_session() as session:
36+
async with AsyncSession(engine) as session:
3837
session.add(hero_1)
3938
session.add(hero_2)
4039
session.add(hero_3)
@@ -43,8 +42,7 @@ async def create_heroes():
4342

4443

4544
async def select_heroes():
46-
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
47-
async with async_session() as session:
45+
async with AsyncSession(engine) as session:
4846
statement = select(Hero)
4947
results = await session.exec(statement)
5048
heroes = results.all()
@@ -58,4 +56,4 @@ async def main():
5856

5957

6058
if __name__ == "__main__":
61-
asyncio.run(main)
59+
asyncio.run(main())

docs_src/tutorial_async/select_async/tutorial004_async.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import asyncio
2-
from typing import Optional
2+
from typing import Annotated, Optional
33

44
from sqlalchemy.ext.asyncio import create_async_engine
5-
from sqlalchemy.orm import sessionmaker
6-
from sqlmodel import Field, Session, SQLModel, select
5+
6+
from sqlmodel import Field, SQLModel, select
77
from sqlmodel.ext.asyncio.session import AsyncSession
88

99

1010
class Hero(SQLModel, table=True):
11-
id: Optional[int] = Field(default=None, primary_key=True)
11+
id: Annotated[int, Field(primary_key=True)]
1212
name: str
1313
secret_name: str
1414
age: Optional[int] = None
@@ -33,8 +33,7 @@ async def create_heroes():
3333
hero_2 = Hero(name="Spider-Boy", secret_name="Pedro Parqueador")
3434
hero_3 = Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)
3535

36-
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
37-
async with async_session() as session:
36+
async with AsyncSession(engine) as session:
3837
session.add(hero_1)
3938
session.add(hero_2)
4039
session.add(hero_3)
@@ -43,8 +42,7 @@ async def create_heroes():
4342

4443

4544
async def select_heroes():
46-
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
47-
async with async_session() as session:
45+
async with AsyncSession(engine) as session:
4846
# TODO: in async, this does not work `await session.exec(select(Hero)).all()`
4947
# heroes = await session.exec(select(Hero)).all()
5048
results = await session.exec(select(Hero))
@@ -59,4 +57,4 @@ async def main():
5957

6058

6159
if __name__ == "__main__":
62-
asyncio.run(main)
60+
asyncio.run(main())

tests/test_async_tutorial/test_async_connect/test_async_update/__init__.py

Whitespace-only changes.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pytest
2+
from unittest.mock import patch
3+
4+
from sqlalchemy.ext.asyncio import create_async_engine
5+
from sqlalchemy.orm import sessionmaker
6+
from sqlmodel import Session, SQLModel, select
7+
from sqlmodel.ext.asyncio.session import AsyncSession
8+
9+
from tests.conftest import get_testing_print_function
10+
11+
expected_calls = [
12+
[
13+
"Created hero:",
14+
{
15+
"age": None,
16+
"id": 1,
17+
"secret_name": "Dive Wilson",
18+
"team_id": 2,
19+
"name": "Deadpond",
20+
},
21+
],
22+
[
23+
"Created hero:",
24+
{
25+
"age": 48,
26+
"id": 2,
27+
"secret_name": "Tommy Sharp",
28+
"team_id": 1,
29+
"name": "Rusty-Man",
30+
},
31+
],
32+
[
33+
"Created hero:",
34+
{
35+
"age": None,
36+
"id": 3,
37+
"secret_name": "Pedro Parqueador",
38+
"team_id": None,
39+
"name": "Spider-Boy",
40+
},
41+
],
42+
[
43+
"Updated hero:",
44+
{
45+
"age": None,
46+
"id": 3,
47+
"secret_name": "Pedro Parqueador",
48+
"team_id": 1,
49+
"name": "Spider-Boy",
50+
},
51+
],
52+
]
53+
54+
55+
@pytest.mark.asyncio()
56+
async def test_tutorial(clear_sqlmodel):
57+
from docs_src.tutorial_async.connect_async.update_async import (
58+
tutorial001_async as mod,
59+
)
60+
61+
mod.sqlite_url = "sqlite+aiosqlite://"
62+
mod.engine = create_async_engine(mod.sqlite_url)
63+
calls = []
64+
65+
new_print = get_testing_print_function(calls)
66+
67+
with patch("builtins.print", new=new_print):
68+
await mod.main()
69+
assert calls == expected_calls

0 commit comments

Comments
 (0)