|
19 | 19 | # |
20 | 20 | # ------------------------------------------------------------- |
21 | 21 | from typing import Dict, List, Tuple, Any, Optional |
22 | | -import numpy as np |
23 | | -from sklearn.model_selection import ParameterGrid |
| 22 | +from skopt import gp_minimize |
| 23 | +from skopt.space import Real, Integer, Categorical |
| 24 | +from skopt.utils import use_named_args |
24 | 25 | import json |
25 | 26 | import logging |
26 | 27 | from dataclasses import dataclass |
27 | 28 | import time |
28 | 29 | import copy |
29 | 30 |
|
30 | 31 | from systemds.scuro.modality.modality import Modality |
31 | | -from systemds.scuro.drsearch.task import Task |
32 | 32 |
|
33 | 33 |
|
34 | 34 | @dataclass |
@@ -163,18 +163,64 @@ def visit_node(node_id): |
163 | 163 | start_time = time.time() |
164 | 164 | rep_name = "_".join([rep.__name__ for rep in reps]) |
165 | 165 |
|
166 | | - param_grid = list(ParameterGrid(hyperparams)) |
167 | | - if max_evals and len(param_grid) > max_evals: |
168 | | - np.random.shuffle(param_grid) |
169 | | - param_grid = param_grid[:max_evals] |
| 166 | + search_space = [] |
| 167 | + param_names = [] |
| 168 | + for param_name, param_values in hyperparams.items(): |
| 169 | + param_names.append(param_name) |
| 170 | + if isinstance(param_values, list): |
| 171 | + if all(isinstance(v, (int, float)) for v in param_values): |
| 172 | + if all(isinstance(v, int) for v in param_values): |
| 173 | + search_space.append( |
| 174 | + Integer( |
| 175 | + min(param_values), max(param_values), name=param_name |
| 176 | + ) |
| 177 | + ) |
| 178 | + else: |
| 179 | + search_space.append( |
| 180 | + Real(min(param_values), max(param_values), name=param_name) |
| 181 | + ) |
| 182 | + else: |
| 183 | + search_space.append(Categorical(param_values, name=param_name)) |
| 184 | + elif isinstance(param_values, tuple) and len(param_values) == 2: |
| 185 | + if isinstance(param_values[0], int) and isinstance( |
| 186 | + param_values[1], int |
| 187 | + ): |
| 188 | + search_space.append( |
| 189 | + Integer(param_values[0], param_values[1], name=param_name) |
| 190 | + ) |
| 191 | + else: |
| 192 | + search_space.append( |
| 193 | + Real(param_values[0], param_values[1], name=param_name) |
| 194 | + ) |
| 195 | + else: |
| 196 | + search_space.append(Categorical([param_values], name=param_name)) |
| 197 | + |
| 198 | + n_calls = max_evals if max_evals else 50 |
170 | 199 |
|
171 | 200 | all_results = [] |
172 | | - for params in param_grid: |
| 201 | + |
| 202 | + @use_named_args(search_space) |
| 203 | + def objective(**params): |
173 | 204 | result = self.evaluate_dag_config( |
174 | 205 | dag, params, node_order, modality_ids, task |
175 | 206 | ) |
176 | 207 | all_results.append(result) |
177 | 208 |
|
| 209 | + score = result[1].average_scores[self.scoring_metric] |
| 210 | + if self.maximize_metric: |
| 211 | + return -score |
| 212 | + else: |
| 213 | + return score |
| 214 | + |
| 215 | + result = gp_minimize( |
| 216 | + objective, |
| 217 | + search_space, |
| 218 | + n_calls=n_calls, |
| 219 | + random_state=42, |
| 220 | + verbose=self.debug, |
| 221 | + n_initial_points=min(10, n_calls // 2), |
| 222 | + ) |
| 223 | + |
178 | 224 | if self.maximize_metric: |
179 | 225 | best_params, best_score = max( |
180 | 226 | all_results, key=lambda x: x[1].average_scores[self.scoring_metric] |
|
0 commit comments