-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdownsample_mimic.py
More file actions
55 lines (48 loc) · 1.41 KB
/
downsample_mimic.py
File metadata and controls
55 lines (48 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from __future__ import annotations
# fmt: off
import sys # isort: skip
from pathlib import Path # isort: skip
ROOT = Path(__file__).resolve().parent.parent # isort: skip
sys.path.append(str(ROOT)) # isort: skip
# fmt: on
import os
import sys
from argparse import Namespace
from pathlib import Path
from sklearn.model_selection import ParameterGrid
from src.enumerables import ClassifierKind, Dataset
from src.prediction import evaluate_downsampling
if __name__ == "__main__":
    # Worker budget passed to evaluate_downsampling for SLURM array tasks:
    # 80 when CC_CLUSTER reports "niagara", otherwise 8.
    MAX_WORKERS = 80 if os.environ.get("CC_CLUSTER") == "niagara" else 8

    # One Namespace per hyperparameter combination; each is a separate job.
    # ParameterGrid is directly iterable, so no intermediate list is needed.
    grid = [
        Namespace(**args)
        for args in ParameterGrid(
            {
                "dataset": [Dataset.MimicIV],
                "kind": [ClassifierKind.SVM],
            },
        )
    ]
    print(f"Total number of combinations: {len(grid)}")

    # Keyword arguments shared by every evaluate_downsampling call below,
    # kept in one place so the two branches cannot drift apart.
    common_kwargs = {
        "downsample": True,
        "n_reps": 200,
        "n_runs": 10,
    }

    idx = os.environ.get("SLURM_ARRAY_TASK_ID")
    if idx is None:
        # Not running as a SLURM array task: evaluate every combination
        # sequentially in this process.
        # NOTE(review): max_workers is not passed here, so the callee's own
        # default applies (unchanged from the original) — confirm intended.
        for args in grid:
            evaluate_downsampling(
                classifier=args.kind,
                dataset=args.dataset,
                **common_kwargs,
            )
    else:
        # SLURM array task: evaluate only the combination selected by the
        # task ID. Validate it explicitly so a bad ID fails loudly instead of
        # raising an opaque error or, for negative IDs, silently picking a
        # combination from the end of the grid.
        try:
            task_index = int(idx)
        except ValueError as err:
            raise ValueError(
                f"SLURM_ARRAY_TASK_ID must be an integer, got {idx!r}"
            ) from err
        if not 0 <= task_index < len(grid):
            raise IndexError(
                f"SLURM_ARRAY_TASK_ID={task_index} is out of range; "
                f"expected 0..{len(grid) - 1}"
            )
        args = grid[task_index]
        evaluate_downsampling(
            classifier=args.kind,
            dataset=args.dataset,
            **common_kwargs,
            max_workers=MAX_WORKERS,
        )