forked from tortoise/tortoise-orm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_prefetching.py
More file actions
152 lines (136 loc) · 7.2 KB
/
test_prefetching.py
File metadata and controls
152 lines (136 loc) · 7.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from tests.testmodels import Address, Event, Team, Tournament
from tortoise.contrib import test
from tortoise.exceptions import FieldError, OperationalError
from tortoise.functions import Count
from tortoise.query_utils import Prefetch
class TestPrefetching(test.TestCase):
    """Tests for ``prefetch_related`` with plain relation names and ``Prefetch`` objects."""

    async def test_prefetch(self):
        """A nested string path (``events__participants``) prefetches both levels."""
        tournament = await Tournament.create(name="tournament")
        event = await Event.create(name="First", tournament=tournament)
        await Event.create(name="Second", tournament=tournament)
        team = await Team.create(name="1")
        team_second = await Team.create(name="2")
        await event.participants.add(team, team_second)

        tournament = await Tournament.all().prefetch_related("events__participants").first()
        # Key the counts by event name rather than list position: no order_by is
        # applied to the prefetched events here, so indexing [0]/[1] would depend on
        # whatever default ordering the model/database happens to use.
        participant_counts = {
            fetched_event.name: len(fetched_event.participants)
            for fetched_event in tournament.events
        }
        self.assertEqual(participant_counts, {"First": 2, "Second": 0})

    async def test_prefetch_object(self):
        """A ``Prefetch`` queryset filters the prefetched rows; a plain name does not."""
        tournament = await Tournament.create(name="tournament")
        await Event.create(name="First", tournament=tournament)
        await Event.create(name="Second", tournament=tournament)

        tournament_with_filtered = (
            await Tournament.all()
            .prefetch_related(Prefetch("events", queryset=Event.filter(name="First")))
            .first()
        )
        tournament = await Tournament.first().prefetch_related("events")

        self.assertEqual(len(tournament_with_filtered.events), 1)
        self.assertEqual(len(tournament.events), 2)

    async def test_prefetch_unknown_field(self):
        """Prefetching a relation name that does not exist raises ``OperationalError``."""
        tournament = await Tournament.create(name="tournament")
        await Event.create(name="First", tournament=tournament)
        await Event.create(name="Second", tournament=tournament)

        # Only the prefetch belongs inside assertRaises: wrapping the setup as well
        # would let an OperationalError from a create() call satisfy the assertion.
        with self.assertRaises(OperationalError):
            await Tournament.all().prefetch_related(
                Prefetch("events1", queryset=Event.filter(name="First"))
            ).first()

    async def test_prefetch_m2m(self):
        """A ``Prefetch`` queryset filters the prefetched side of a many-to-many relation."""
        tournament = await Tournament.create(name="tournament")
        event = await Event.create(name="First", tournament=tournament)
        team = await Team.create(name="1")
        team_second = await Team.create(name="2")
        await event.participants.add(team, team_second)

        fetched_event = (
            await Event.all()
            .prefetch_related(Prefetch("participants", queryset=Team.filter(name="1")))
            .first()
        )
        self.assertEqual(len(fetched_event.participants), 1)

    async def test_prefetch_o2o(self):
        """Prefetching the ``address`` relation loads the related object onto the event."""
        tournament = await Tournament.create(name="tournament")
        event = await Event.create(name="First", tournament=tournament)
        await Address.create(city="Santa Monica", street="Ocean", event=event)

        fetched_event = await Event.all().prefetch_related("address").first()
        self.assertEqual(fetched_event.address.city, "Santa Monica")

    async def test_prefetch_nested(self):
        """Separate ``Prefetch`` objects can filter each level of a nested path."""
        tournament = await Tournament.create(name="tournament")
        event = await Event.create(name="First", tournament=tournament)
        await Event.create(name="Second", tournament=tournament)
        team = await Team.create(name="1")
        team_second = await Team.create(name="2")
        await event.participants.add(team, team_second)

        fetched_tournament = (
            await Tournament.all()
            .prefetch_related(
                Prefetch("events", queryset=Event.filter(name="First")),
                Prefetch("events__participants", queryset=Team.filter(name="1")),
            )
            .first()
        )
        # The events filter leaves exactly one event, so index 0 is unambiguous.
        self.assertEqual(len(fetched_tournament.events[0].participants), 1)

    async def test_prefetch_nested_with_aggregation(self):
        """A ``Prefetch`` queryset may filter on an annotated aggregate."""
        tournament = await Tournament.create(name="tournament")
        event = await Event.create(name="First", tournament=tournament)
        await Event.create(name="Second", tournament=tournament)
        team = await Team.create(name="1")
        team_second = await Team.create(name="2")
        await event.participants.add(team, team_second)

        fetched_tournament = (
            await Tournament.all()
            .prefetch_related(
                Prefetch(
                    "events", queryset=Event.annotate(teams=Count("participants")).filter(teams=2)
                )
            )
            .first()
        )
        # Only "First" has two participants, so it is the sole prefetched event.
        self.assertEqual(len(fetched_tournament.events), 1)
        self.assertEqual(fetched_tournament.events[0].pk, event.pk)

    async def test_prefetch_direct_relation(self):
        """Prefetching the forward ``tournament`` relation on ``Event`` loads the tournament."""
        tournament = await Tournament.create(name="tournament")
        await Event.create(name="First", tournament=tournament)

        event = await Event.first().prefetch_related("tournament")
        self.assertEqual(event.tournament.id, tournament.id)

    async def test_prefetch_bad_key(self):
        """An unknown relation name passed as a string raises ``FieldError``."""
        tournament = await Tournament.create(name="tournament")
        await Event.create(name="First", tournament=tournament)

        with self.assertRaisesRegex(FieldError, "Relation tour1nament for models.Event not found"):
            await Event.first().prefetch_related("tour1nament")

    async def test_prefetch_m2m_filter(self):
        """A positional ``Prefetch`` queryset filters a many-to-many prefetch."""
        tournament = await Tournament.create(name="tournament")
        team = await Team.create(name="1")
        team_second = await Team.create(name="2")
        event = await Event.create(name="First", tournament=tournament)
        await event.participants.add(team, team_second)

        event = await Event.first().prefetch_related(
            Prefetch("participants", Team.filter(name="2"))
        )
        self.assertEqual(len(event.participants), 1)
        self.assertEqual(list(event.participants), [team_second])

    async def test_prefetch_m2m_to_attr(self):
        """``to_attr`` stores each differently-filtered m2m prefetch under its own attribute."""
        tournament = await Tournament.create(name="tournament")
        team = await Team.create(name="1")
        team_second = await Team.create(name="2")
        event = await Event.create(name="First", tournament=tournament)
        await event.participants.add(team, team_second)

        event = await Event.first().prefetch_related(
            Prefetch("participants", Team.filter(name="1"), to_attr="to_attr_participants_1"),
            Prefetch("participants", Team.filter(name="2"), to_attr="to_attr_participants_2"),
        )
        self.assertEqual(list(event.to_attr_participants_1), [team])
        self.assertEqual(list(event.to_attr_participants_2), [team_second])

    async def test_prefetch_o2o_to_attr(self):
        """``to_attr`` stores a prefetched ``address`` under a custom attribute name."""
        tournament = await Tournament.create(name="tournament")
        event = await Event.create(name="First", tournament=tournament)
        address = await Address.create(city="Santa Monica", street="Ocean", event=event)

        event = await Event.get(pk=event.pk).prefetch_related(
            Prefetch("address", to_attr="to_address", queryset=Address.all())
        )
        self.assertEqual(address.pk, event.to_address.pk)

    async def test_prefetch_direct_relation_to_attr(self):
        """``to_attr`` stores a prefetched forward relation under a custom attribute name."""
        tournament = await Tournament.create(name="tournament")
        await Event.create(name="First", tournament=tournament)

        event = await Event.first().prefetch_related(
            Prefetch("tournament", queryset=Tournament.all(), to_attr="to_attr_tournament")
        )
        self.assertEqual(event.to_attr_tournament.id, tournament.id)