data.py
import io
import json
from collections import defaultdict
from copy import deepcopy
from functools import cache
from itertools import chain, count

from django.apps import apps
from django.conf import settings
from django.core import serializers
from django.core.management.color import no_style
from django.db import DEFAULT_DB_ALIAS, connections, transaction
from django.utils.crypto import get_random_string
from django.utils.module_loading import import_string

from feincms3_data.serializers import JSONEncoder, JSONSerializer


def datasets():
    return import_string(settings.FEINCMS3_DATA_DATASETS)()
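
# Illustrative sketch (not part of the original module): the setting is
# expected to hold the dotted path of a callable; what that callable returns
# is project-defined. The names below ("app.data.datasets", the "pages" app)
# and the return shape are assumptions, not this package's documented API.
#
#   # settings.py
#   FEINCMS3_DATA_DATASETS = "app.data.datasets"
#
#   # app/data.py
#   def datasets():
#       # Assumed shape; adjust to whatever your project expects.
#       return {"pages": {"specs": specs_for_app_models("pages")}}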


def _all_subclasses(cls):
    for sc in cls.__subclasses__():
        yield sc
        yield from _all_subclasses(sc)


def _only_concrete_models(iterable):
    for model in iterable:
        if not model._meta.abstract and not model._meta.proxy:
            yield model


def _random_values():
    """Generate a stream of values which are unlikely to cause conflicts"""
    prefix = get_random_string(20)
    for i in count():
        yield f"{prefix}-{i}"


class InvalidVersionError(Exception):
    pass


class InvalidSpecError(Exception):
    pass


_valid_keys = {
    "model",
    "filter",
    # Flags:
    "delete_missing",
    "ignore_missing_m2m",
    "save_as_new",
    "defer_values",
}


def _validate_spec(spec):
    if "model" not in spec:
        raise InvalidSpecError(f"The spec {spec!r} requires a 'model' key")
    if unknown := (set(spec.keys()) - _valid_keys):
        raise InvalidSpecError(f"The spec {spec!r} contains unknown keys: {unknown!r}")
    return spec


def specs_for_models(models, spec=None):
    spec = {} if spec is None else spec
    return [_validate_spec({**spec, "model": cls._meta.label_lower}) for cls in models]


def specs_for_derived_models(cls, spec=None):
    return specs_for_models(_only_concrete_models(_all_subclasses(cls)), spec)


def specs_for_app_models(app, spec=None):
    return specs_for_models(apps.get_app_config(app).get_models(), spec)
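
# Illustrative example: a spec is a plain dict with a lower-cased "model"
# label plus the optional "filter" and flag keys listed in _valid_keys above.
# Assuming a hypothetical pages.Page model,
#
#   specs_for_models([Page], {"save_as_new": True})
#
# would return [{"model": "pages.page", "save_as_new": True}].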


def _model_queryset(spec):
    queryset = apps.get_model(spec["model"])._default_manager.order_by("pk")
    if f := spec.get("filter"):
        queryset = queryset.filter(**f)
    return queryset


def silence(*a):
    pass


def dump_specs(specs, *, mappers=None, objects=None):
    # Assemble the JSON envelope by hand; the serializer then writes the
    # objects directly into the same stream.
    stream = io.StringIO()
    stream.write('{"version": 1, "specs": ')
    json.dump(specs, stream, cls=JSONEncoder)
    stream.write(', "objects": ')
    serializer = JSONSerializer(mappers=mappers or {})
    if objects is None:
        objects = chain.from_iterable(
            _model_queryset(spec).distinct() for spec in specs
        )
    serializer.serialize(objects, stream=stream)
    return stream.getvalue().rstrip("\n") + "}\n"
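
# Illustrative usage sketch (the app label and file name are hypothetical):
#
#   specs = specs_for_app_models("pages", {"delete_missing": True})
#   with open("pages.json", "w") as fh:
#       fh.write(dump_specs(specs))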


def load_dump(
    data, *, progress=silence, ignorenonexistent=False, using=DEFAULT_DB_ALIAS
):
    if data["version"] != 1:
        raise InvalidVersionError(f"Invalid dump version {data.get('version')!r}")
    for spec in data["specs"]:
        _validate_spec(spec)

    objects = defaultdict(list)
    seen_pks = defaultdict(set)

    # Yes, that is a bit stupid: round-trip through json.dumps because
    # serializers.deserialize() wants a JSON string, not already-parsed data.
    for ds in serializers.deserialize(
        "json",
        json.dumps(data["objects"]),
        ignorenonexistent=ignorenonexistent,
        # handle_forward_references=True,
    ):
        objects[ds.object._meta.label_lower].append(ds)
    progress(f"Loaded {len(data['objects'])} objects")

    save_as_new_models = {
        spec["model"] for spec in data["specs"] if spec.get("save_as_new")
    }

    with transaction.atomic(using=using):
        connection = connections[using]
        with connection.constraint_checks_disabled():
            models = set()
            _load_dump(
                data,
                objects,
                progress,
                seen_pks,
                save_as_new_models,
                models,
            )
        _finalize(
            progress,
            connection,
            models,
        )
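
# Illustrative counterpart to the dump sketch above: load_dump() expects the
# already-parsed dict, validates the version and every spec, and applies
# everything inside a single transaction.
#
#   with open("pages.json") as fh:
#       load_dump(json.load(fh), progress=print)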


def _load_dump(
    data,
    objects,
    progress,
    seen_pks,
    save_as_new_models,
    models,
):
    save_as_new_pk_map = defaultdict(dict)
    ignore_missing_m2m_data = defaultdict(dict)
    deferred_new_pks = []
    deferred_values = []
    deferred_m2m = []

    # First pass: save objects spec by spec, temporarily stripping ignored
    # M2M data and replacing deferred values with conflict-free placeholders.
    for spec in data["specs"]:
        if objs := objects[spec["model"]]:
            for ds in objs:
                for field_name in spec.get("ignore_missing_m2m", ()):
                    ignore_missing_m2m_data[ds][field_name] = ds.m2m_data.pop(
                        field_name, []
                    )

                random_value = _random_values()
                for field_name in spec.get("defer_values", ()):
                    deferred_values.append(
                        (ds, field_name, getattr(ds.object, field_name))
                    )
                    setattr(ds.object, field_name, next(random_value))

                _do_save(
                    ds,
                    pk_map=save_as_new_pk_map,
                    save_as_new_models=save_as_new_models,
                    deferred_new_pks=deferred_new_pks,
                    deferred_m2m=deferred_m2m,
                )
                seen_pks[ds.object._meta.label_lower].add(ds.object.pk)
                models.add(ds.object.__class__)
            progress(f"Saved {len(objs)} {spec['model']} objects")

    _save_deferred_new_pks(deferred_new_pks)
    _save_deferred_m2m(deferred_m2m)

    # Delete objects missing from the dump when requested, in reverse spec order.
    for spec in reversed(data["specs"]):
        if not spec.get("delete_missing"):
            continue
        if isinstance(spec["delete_missing"], dict) and (
            map := spec["delete_missing"].get("map")
        ):
            queryset = _model_queryset(_map_spec(spec, map, save_as_new_pk_map))
        else:
            queryset = _model_queryset(spec)
        deleted = queryset.exclude(pk__in=seen_pks[spec["model"]]).delete()
        if deleted[0]:
            progress(f"Deleted {spec['model']} objects: {deleted}")

    # Re-attach ignored M2M values, keeping only primary keys which exist.
    pks = pk_cache()
    for ds, lists in ignore_missing_m2m_data.items():
        for field_name, field_pks in lists.items():
            field = ds.object._meta.get_field(field_name)
            existing = pks(field.related_model)
            getattr(ds.object, field_name).set(set(field_pks) & existing)

    # Restore deferred field values now that all objects have been inserted.
    for ds, field_name, value in deferred_values:
        setattr(ds.object, field_name, value)
        ds.save()


def _map_spec(spec, map, save_as_new_pk_map):
    spec = deepcopy(spec)
    for key, model in map:
        cls = apps.get_model(model)
        if isinstance(spec["filter"][key], (list, tuple)):
            spec["filter"][key] = [
                save_as_new_pk_map[cls][pk] for pk in spec["filter"][key]
            ]
        else:
            spec["filter"][key] = save_as_new_pk_map[cls][spec["filter"][key]]
    return spec


def _save_deferred_new_pks(deferred_new_pks):
    for ds, f_name, pk_map, fk in deferred_new_pks:
        setattr(ds.object, f_name, pk_map[fk])
        ds.save()


def _save_deferred_m2m(deferred_m2m):
    for ds, m2m_data, f_name, pk_map in deferred_m2m:
        if pks := m2m_data.get(f_name):
            getattr(ds, f_name).set([pk_map[pk] for pk in pks])


def _finalize(
    progress,
    connection,
    models,
):
    table_names = [model._meta.db_table for model in models]
    try:
        connection.check_constraints(table_names=table_names)
    except Exception as e:
        e.args = ("Problem installing fixtures: %s" % e,)
        raise

    sequence_sql = connection.ops.sequence_reset_sql(no_style(), models)
    if sequence_sql:
        progress("Resetting sequences")
        with connection.cursor() as cursor:
            for line in sequence_sql:
                cursor.execute(line)


def pk_cache():
    @cache
    def pks(model):
        return set(model._default_manager.values_list("pk", flat=True))

    return pks


_sentinel = object()


def _do_save(ds, *, pk_map, save_as_new_models, deferred_new_pks, deferred_m2m):
    # Map old PKs to new
    for f in ds.object._meta.get_fields():
        if f.many_to_many and f.related_model._meta.label_lower in save_as_new_models:
            # Always defer
            deferred_m2m.append(
                (ds.object, ds.m2m_data.copy(), f.name, pk_map[f.related_model])
            )
        elif (
            f.concrete
            and f.related_model
            and f.related_model._meta.label_lower in save_as_new_models
            and (fk := getattr(ds.object, f.column)) is not None
        ):
            if (new_pk := pk_map[f.related_model].get(fk, _sentinel)) is not _sentinel:
                setattr(ds.object, f.name, new_pk)
            else:
                # If foreign key isn't nullable we're toast.
                setattr(ds.object, f.name, None)
                # But if it is, we can defer.
                deferred_new_pks.append((ds, f.name, pk_map[f.related_model], fk))

    if ds.object._meta.label_lower in save_as_new_models:
        # Do the saving
        old_pk = ds.object.pk
        ds.object.pk = None
        ds.save(force_insert=True)
        pk_map[ds.object.__class__][old_pk] = ds.object
    else:
        ds.save()