Skip to content

Commit 15dbd5d

Browse files
committed
Create branching.py
1 parent fb1e5c9 commit 15dbd5d

File tree

1 file changed

+283
-0
lines changed

1 file changed

+283
-0
lines changed

modelsync/experiments/branching.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
"""
2+
Experiment branching and comparison for ModelSync
3+
"""
4+
5+
import json
6+
import shutil
7+
from pathlib import Path
8+
from typing import Dict, List, Optional, Any, Tuple
9+
from datetime import datetime
10+
from modelsync.utils.helpers import ensure_directory, write_json_file, read_json_file
11+
12+
class ExperimentBranch:
13+
"""Represents an experiment branch"""
14+
15+
def __init__(self, name: str, base_branch: Optional[str] = None, repo_path: str = "."):
16+
self.name = name
17+
self.base_branch = base_branch
18+
self.repo_path = Path(repo_path)
19+
self.branch_dir = self.repo_path / ".modelsync" / "branches" / name
20+
self.experiments_dir = self.branch_dir / "experiments"
21+
self.metrics_file = self.branch_dir / "metrics.json"
22+
self._setup_branch()
23+
24+
def _setup_branch(self):
25+
"""Setup branch directory structure"""
26+
ensure_directory(str(self.branch_dir))
27+
ensure_directory(str(self.experiments_dir))
28+
29+
if not self.metrics_file.exists():
30+
write_json_file(str(self.metrics_file), {
31+
"branch_name": self.name,
32+
"base_branch": self.base_branch,
33+
"created_at": datetime.now().isoformat(),
34+
"experiments": [],
35+
"best_metrics": {},
36+
"status": "active"
37+
})
38+
39+
def add_experiment(
40+
self,
41+
experiment_name: str,
42+
model_id: str,
43+
dataset_id: str,
44+
hyperparameters: Dict[str, Any],
45+
metrics: Dict[str, float],
46+
description: str = ""
47+
) -> Dict[str, Any]:
48+
"""Add an experiment to this branch"""
49+
50+
experiment_data = {
51+
"id": f"{self.name}_{experiment_name}",
52+
"name": experiment_name,
53+
"branch": self.name,
54+
"model_id": model_id,
55+
"dataset_id": dataset_id,
56+
"hyperparameters": hyperparameters,
57+
"metrics": metrics,
58+
"description": description,
59+
"created_at": datetime.now().isoformat(),
60+
"status": "completed"
61+
}
62+
63+
# Save experiment data
64+
experiment_file = self.experiments_dir / f"{experiment_name}.json"
65+
write_json_file(str(experiment_file), experiment_data)
66+
67+
# Update branch metrics
68+
self._update_branch_metrics(experiment_data)
69+
70+
print(f"✅ Experiment added to branch '{self.name}': {experiment_name}")
71+
return experiment_data
72+
73+
def _update_branch_metrics(self, experiment: Dict[str, Any]):
74+
"""Update branch-level metrics"""
75+
branch_data = read_json_file(str(self.metrics_file))
76+
branch_data["experiments"].append(experiment["id"])
77+
78+
# Update best metrics
79+
for metric, value in experiment["metrics"].items():
80+
if metric not in branch_data["best_metrics"]:
81+
branch_data["best_metrics"][metric] = value
82+
else:
83+
# Keep the best value (assuming higher is better)
84+
if value > branch_data["best_metrics"][metric]:
85+
branch_data["best_metrics"][metric] = value
86+
87+
write_json_file(str(self.metrics_file), branch_data)
88+
89+
def get_experiments(self) -> List[Dict[str, Any]]:
90+
"""Get all experiments in this branch"""
91+
experiments = []
92+
93+
for experiment_file in self.experiments_dir.glob("*.json"):
94+
experiment = read_json_file(str(experiment_file))
95+
if experiment:
96+
experiments.append(experiment)
97+
98+
return sorted(experiments, key=lambda x: x["created_at"], reverse=True)
99+
100+
def get_best_experiment(self, metric: str) -> Optional[Dict[str, Any]]:
101+
"""Get the best experiment for a specific metric"""
102+
experiments = self.get_experiments()
103+
if not experiments:
104+
return None
105+
106+
best_experiment = None
107+
best_value = float('-inf')
108+
109+
for experiment in experiments:
110+
value = experiment.get("metrics", {}).get(metric, float('-inf'))
111+
if value > best_value:
112+
best_value = value
113+
best_experiment = experiment
114+
115+
return best_experiment
116+
117+
def get_metrics_summary(self) -> Dict[str, Any]:
118+
"""Get summary of all metrics in this branch"""
119+
experiments = self.get_experiments()
120+
if not experiments:
121+
return {}
122+
123+
all_metrics = set()
124+
for experiment in experiments:
125+
all_metrics.update(experiment.get("metrics", {}).keys())
126+
127+
summary = {}
128+
for metric in all_metrics:
129+
values = [exp.get("metrics", {}).get(metric, 0) for exp in experiments]
130+
summary[metric] = {
131+
"count": len(values),
132+
"min": min(values),
133+
"max": max(values),
134+
"avg": sum(values) / len(values),
135+
"std": self._calculate_std(values)
136+
}
137+
138+
return summary
139+
140+
def _calculate_std(self, values: List[float]) -> float:
141+
"""Calculate standard deviation"""
142+
if len(values) < 2:
143+
return 0.0
144+
145+
mean = sum(values) / len(values)
146+
variance = sum((x - mean) ** 2 for x in values) / (len(values) - 1)
147+
return variance ** 0.5
148+
149+
class ExperimentManager:
150+
"""Manages experiment branches and comparisons"""
151+
152+
def __init__(self, repo_path: str = "."):
153+
self.repo_path = Path(repo_path)
154+
self.branches_dir = self.repo_path / ".modelsync" / "branches"
155+
ensure_directory(str(self.branches_dir))
156+
157+
def create_branch(self, name: str, base_branch: Optional[str] = None) -> ExperimentBranch:
158+
"""Create a new experiment branch"""
159+
if self.branch_exists(name):
160+
raise ValueError(f"Branch '{name}' already exists")
161+
162+
branch = ExperimentBranch(name, base_branch, str(self.repo_path))
163+
print(f"✅ Created experiment branch: {name}")
164+
return branch
165+
166+
def get_branch(self, name: str) -> Optional[ExperimentBranch]:
167+
"""Get an existing branch"""
168+
if not self.branch_exists(name):
169+
return None
170+
171+
return ExperimentBranch(name, repo_path=str(self.repo_path))
172+
173+
def branch_exists(self, name: str) -> bool:
174+
"""Check if branch exists"""
175+
return (self.branches_dir / name).exists()
176+
177+
def list_branches(self) -> List[str]:
178+
"""List all experiment branches"""
179+
if not self.branches_dir.exists():
180+
return []
181+
182+
return [d.name for d in self.branches_dir.iterdir() if d.is_dir()]
183+
184+
def compare_branches(self, branch_names: List[str], metric: str) -> Dict[str, Any]:
185+
"""Compare multiple branches by a specific metric"""
186+
branches = [self.get_branch(name) for name in branch_names if self.get_branch(name)]
187+
188+
if len(branches) < 2:
189+
return {"error": "Need at least 2 branches to compare"}
190+
191+
comparison = {
192+
"metric": metric,
193+
"branches": [],
194+
"best_branch": None,
195+
"worst_branch": None
196+
}
197+
198+
best_value = float('-inf')
199+
worst_value = float('inf')
200+
201+
for branch in branches:
202+
experiments = branch.get_experiments()
203+
if not experiments:
204+
continue
205+
206+
# Calculate average metric value for this branch
207+
values = [exp.get("metrics", {}).get(metric, 0) for exp in experiments]
208+
avg_value = sum(values) / len(values) if values else 0
209+
210+
branch_data = {
211+
"name": branch.name,
212+
"experiment_count": len(experiments),
213+
"avg_metric_value": avg_value,
214+
"best_experiment": branch.get_best_experiment(metric)
215+
}
216+
217+
comparison["branches"].append(branch_data)
218+
219+
if avg_value > best_value:
220+
best_value = avg_value
221+
comparison["best_branch"] = branch.name
222+
223+
if avg_value < worst_value:
224+
worst_value = avg_value
225+
comparison["worst_branch"] = branch.name
226+
227+
return comparison
228+
229+
def merge_branch(
230+
self,
231+
source_branch: str,
232+
target_branch: str,
233+
merge_strategy: str = "best_experiment"
234+
) -> bool:
235+
"""Merge experiment branch into another branch"""
236+
237+
source = self.get_branch(source_branch)
238+
target = self.get_branch(target_branch)
239+
240+
if not source or not target:
241+
print("❌ Source or target branch not found")
242+
return False
243+
244+
if merge_strategy == "best_experiment":
245+
# Find best experiment in source branch
246+
experiments = source.get_experiments()
247+
if not experiments:
248+
print("❌ No experiments in source branch")
249+
return False
250+
251+
# Get best experiment by average score
252+
best_experiment = None
253+
best_score = float('-inf')
254+
255+
for experiment in experiments:
256+
avg_score = sum(experiment.get("metrics", {}).values()) / len(experiment.get("metrics", {})) if experiment.get("metrics") else 0
257+
if avg_score > best_score:
258+
best_score = avg_score
259+
best_experiment = experiment
260+
261+
if best_experiment:
262+
# Add best experiment to target branch
263+
target.add_experiment(
264+
f"{best_experiment['name']}_merged",
265+
best_experiment["model_id"],
266+
best_experiment["dataset_id"],
267+
best_experiment["hyperparameters"],
268+
best_experiment["metrics"],
269+
f"Merged from {source_branch}: {best_experiment['description']}"
270+
)
271+
272+
print(f"✅ Merged branch '{source_branch}' into '{target_branch}'")
273+
return True
274+
275+
def delete_branch(self, name: str) -> bool:
276+
"""Delete an experiment branch"""
277+
if not self.branch_exists(name):
278+
print(f"❌ Branch '{name}' not found")
279+
return False
280+
281+
shutil.rmtree(self.branches_dir / name)
282+
print(f"✅ Deleted branch: {name}")
283+
return True

0 commit comments

Comments
 (0)