Skip to content

Commit 3e151a0

Browse files
committed
Create dataset_storage.py
1 parent 48e140c commit 3e151a0

File tree

1 file changed

+258
-0
lines changed

1 file changed

+258
-0
lines changed
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
"""
2+
Dataset storage and versioning for ModelSync
3+
"""
4+
5+
import os
6+
import hashlib
7+
import json
8+
import shutil
9+
from pathlib import Path
10+
from typing import Dict, List, Optional, Any, Union
11+
from datetime import datetime
12+
import boto3
13+
from google.cloud import storage
14+
from modelsync.utils.helpers import calculate_file_hash, ensure_directory, write_json_file, read_json_file
15+
16+
class DatasetStorage:
    """Manages dataset storage and versioning with local and cloud support.

    Datasets are content-addressed: a SHA-256 digest over file contents
    (and, for directories, relative paths) identifies each dataset and
    drives deduplication. A local copy is always kept under
    ``<repo>/.modelsync/storage/datasets``; datasets can additionally be
    mirrored to AWS S3 or Google Cloud Storage.
    """

    def __init__(self, repo_path: str = ".", config: Optional[Dict] = None):
        """Initialize storage rooted at *repo_path*.

        Args:
            repo_path: Repository root; storage lives under ``.modelsync/``.
            config: Optional cloud settings, e.g.
                ``{"aws": {"access_key_id": ..., "secret_access_key": ...,
                "bucket": ...}, "gcs": {"project_id": ..., "bucket": ...}}``.
        """
        self.repo_path = Path(repo_path)
        self.storage_dir = self.repo_path / ".modelsync" / "storage" / "datasets"
        self.config = config or {}
        self._setup_cloud_clients()

    def _setup_cloud_clients(self):
        """Create S3/GCS clients when the matching config section is present."""
        self.s3_client = None
        self.gcs_client = None

        # AWS S3: only built when explicit credentials are configured.
        if self.config.get("aws", {}).get("access_key_id"):
            self.s3_client = boto3.client(
                's3',
                aws_access_key_id=self.config["aws"]["access_key_id"],
                aws_secret_access_key=self.config["aws"]["secret_access_key"],
                region_name=self.config["aws"].get("region", "us-east-1")
            )

        # Google Cloud Storage: uses ambient credentials (ADC) for the project.
        if self.config.get("gcs", {}).get("project_id"):
            self.gcs_client = storage.Client(
                project=self.config["gcs"]["project_id"]
            )

    def add_dataset(
        self,
        dataset_path: str,
        dataset_name: str,
        description: str = "",
        tags: Optional[List[str]] = None,
        cloud_storage: Optional[str] = None,
        deduplicate: bool = True
    ) -> Dict[str, Any]:
        """Add a dataset (single file or directory) to version control.

        Args:
            dataset_path: Path to the dataset file or directory.
            dataset_name: Human-readable name for the dataset.
            description: Free-form description.
            tags: Optional labels used for filtering in :meth:`list_datasets`.
            cloud_storage: ``"s3"`` or ``"gcs"`` to also mirror to the cloud.
            deduplicate: When True, return the existing record if a dataset
                with an identical content hash is already stored.

        Returns:
            The dataset's metadata record.

        Raises:
            FileNotFoundError: If *dataset_path* does not exist.
        """
        dataset_path = Path(dataset_path)
        if not dataset_path.exists():
            raise FileNotFoundError(f"Dataset not found: {dataset_path}")

        # The content hash doubles as the identity key for deduplication.
        dataset_hash = self._calculate_dataset_hash(dataset_path)

        if deduplicate:
            existing = self._find_existing_dataset(dataset_hash)
            if existing:
                print(f"📦 Dataset already exists: {existing['name']} ({dataset_hash[:8]})")
                return existing

        dataset_metadata = {
            "id": dataset_hash[:16],
            "name": dataset_name,
            "description": description,
            "tags": tags or [],
            "original_path": str(dataset_path),
            "hash": dataset_hash,
            "size": self._calculate_dataset_size(dataset_path),
            "file_count": self._count_files(dataset_path),
            # NOTE(review): naive local timestamp; consider UTC if records
            # are compared across machines.
            "created_at": datetime.now().isoformat(),
            "cloud_storage": cloud_storage,
            "storage_info": {}
        }

        # Always keep a local copy; cloud upload is additive.
        self._store_dataset_locally(dataset_path, dataset_metadata)

        if cloud_storage:
            # Fills dataset_metadata["storage_info"] with bucket/key details.
            self._upload_to_cloud(dataset_path, dataset_metadata, cloud_storage)

        # Persist metadata last so storage_info reflects any cloud upload.
        self._save_dataset_metadata(dataset_metadata)

        print(f"✅ Dataset added: {dataset_name} ({dataset_hash[:8]})")
        return dataset_metadata

    def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """Return the metadata record for *dataset_id*, or None if unknown."""
        metadata_file = self.storage_dir / "metadata" / f"{dataset_id}.json"
        if metadata_file.exists():
            return read_json_file(str(metadata_file))
        return None

    def list_datasets(self, tags: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """List all dataset records, newest first.

        Args:
            tags: When given, keep only datasets carrying at least one of
                these tags.
        """
        datasets = []
        metadata_dir = self.storage_dir / "metadata"

        if not metadata_dir.exists():
            return datasets

        for metadata_file in metadata_dir.glob("*.json"):
            dataset = read_json_file(str(metadata_file))
            if dataset:
                # No tag filter, or any requested tag matches.
                if not tags or any(tag in dataset.get("tags", []) for tag in tags):
                    datasets.append(dataset)

        return sorted(datasets, key=lambda x: x["created_at"], reverse=True)

    def download_dataset(self, dataset_id: str, target_path: str) -> bool:
        """Materialize a stored dataset at *target_path*.

        Prefers the local copy; falls back to cloud storage when configured.
        Returns True on success, False if the dataset is unknown or
        unavailable.
        """
        dataset = self.get_dataset(dataset_id)
        if not dataset:
            return False

        # Prefer the local copy when present.
        local_path = self.storage_dir / "datasets" / dataset_id
        if local_path.exists():
            if local_path.is_file():
                # Single-file dataset: copytree would raise on a file.
                Path(target_path).parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(local_path, target_path)
            else:
                shutil.copytree(local_path, target_path, dirs_exist_ok=True)
            return True

        # Fall back to cloud storage when the dataset was mirrored there.
        if dataset.get("cloud_storage"):
            return self._download_from_cloud(dataset, target_path)

        return False

    def _calculate_dataset_hash(self, dataset_path: Path) -> str:
        """Compute a SHA-256 content hash for a dataset file or directory.

        For directories the relative path of each file is folded into the
        digest alongside its content hash, so renaming or moving files
        produces a different dataset hash (otherwise two directories with
        identical contents but different layouts would wrongly deduplicate).
        """
        if dataset_path.is_file():
            parts = [calculate_file_hash(str(dataset_path))]
        else:
            # Sorted traversal keeps the digest deterministic across runs.
            parts = [
                f"{file_path.relative_to(dataset_path).as_posix()}:"
                f"{calculate_file_hash(str(file_path))}"
                for file_path in sorted(dataset_path.rglob("*"))
                if file_path.is_file()
            ]

        return hashlib.sha256("".join(parts).encode()).hexdigest()

    def _calculate_dataset_size(self, dataset_path: Path) -> int:
        """Return the total size in bytes of a dataset file or directory."""
        if dataset_path.is_file():
            return dataset_path.stat().st_size

        return sum(
            file_path.stat().st_size
            for file_path in dataset_path.rglob("*")
            if file_path.is_file()
        )

    def _count_files(self, dataset_path: Path) -> int:
        """Return the number of regular files in the dataset."""
        if dataset_path.is_file():
            return 1

        return len([f for f in dataset_path.rglob("*") if f.is_file()])

    def _find_existing_dataset(self, dataset_hash: str) -> Optional[Dict[str, Any]]:
        """Return the stored record whose content hash matches, if any."""
        for dataset in self.list_datasets():
            if dataset["hash"] == dataset_hash:
                return dataset
        return None

    def _store_dataset_locally(self, dataset_path: Path, metadata: Dict[str, Any]):
        """Copy the dataset into local content-addressed storage."""
        dataset_id = metadata["id"]
        local_storage_path = self.storage_dir / "datasets" / dataset_id

        if dataset_path.is_file():
            local_storage_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(dataset_path, local_storage_path)
        else:
            shutil.copytree(dataset_path, local_storage_path, dirs_exist_ok=True)

    def _upload_to_cloud(self, dataset_path: Path, metadata: Dict[str, Any], cloud_type: str):
        """Upload the dataset to S3 or GCS and record where it landed.

        Mutates ``metadata["storage_info"]`` with the bucket/key (S3) or
        bucket/prefix (GCS). Silently does nothing if the requested client
        was not configured.
        """
        dataset_id = metadata["id"]

        if cloud_type == "s3" and self.s3_client:
            bucket = self.config["aws"]["bucket"]
            key = f"datasets/{dataset_id}"

            if dataset_path.is_file():
                self.s3_client.upload_file(str(dataset_path), bucket, key)
            else:
                # Directories: one object per file under the dataset prefix.
                for file_path in dataset_path.rglob("*"):
                    if file_path.is_file():
                        relative_path = file_path.relative_to(dataset_path)
                        s3_key = f"datasets/{dataset_id}/{relative_path}"
                        self.s3_client.upload_file(str(file_path), bucket, s3_key)

            metadata["storage_info"]["s3"] = {
                "bucket": bucket,
                "key": key
            }

        elif cloud_type == "gcs" and self.gcs_client:
            bucket_name = self.config["gcs"]["bucket"]
            bucket = self.gcs_client.bucket(bucket_name)

            if dataset_path.is_file():
                blob = bucket.blob(f"datasets/{dataset_id}")
                blob.upload_from_filename(str(dataset_path))
            else:
                # Directories: one blob per file under the dataset prefix.
                for file_path in dataset_path.rglob("*"):
                    if file_path.is_file():
                        relative_path = file_path.relative_to(dataset_path)
                        blob_name = f"datasets/{dataset_id}/{relative_path}"
                        blob = bucket.blob(blob_name)
                        blob.upload_from_filename(str(file_path))

            metadata["storage_info"]["gcs"] = {
                "bucket": bucket_name,
                "prefix": f"datasets/{dataset_id}"
            }

    def _download_from_cloud(self, dataset: Dict[str, Any], target_path: str) -> bool:
        """Download a mirrored dataset from S3 or GCS into *target_path*.

        Objects are fetched under the recorded key/prefix; a single-file
        dataset (the key itself) lands at *target_path*, a directory
        dataset is recreated beneath it. Returns True only when at least
        one object was actually downloaded.
        """
        storage_info = dataset.get("storage_info", {})
        target = Path(target_path)

        if "s3" in storage_info and self.s3_client:
            s3_info = storage_info["s3"]
            bucket = s3_info["bucket"]
            prefix = s3_info["key"]
            downloaded = False
            paginator = self.s3_client.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
                for obj in page.get("Contents", []):
                    # Empty relative path == single-file dataset at the key.
                    relative = obj["Key"][len(prefix):].lstrip("/")
                    destination = target / relative if relative else target
                    destination.parent.mkdir(parents=True, exist_ok=True)
                    self.s3_client.download_file(bucket, obj["Key"], str(destination))
                    downloaded = True
            return downloaded

        elif "gcs" in storage_info and self.gcs_client:
            gcs_info = storage_info["gcs"]
            prefix = gcs_info["prefix"]
            downloaded = False
            for blob in self.gcs_client.list_blobs(gcs_info["bucket"], prefix=prefix):
                relative = blob.name[len(prefix):].lstrip("/")
                destination = target / relative if relative else target
                destination.parent.mkdir(parents=True, exist_ok=True)
                blob.download_to_filename(str(destination))
                downloaded = True
            return downloaded

        return False

    def _save_dataset_metadata(self, metadata: Dict[str, Any]):
        """Write the metadata record to ``metadata/<id>.json``."""
        metadata_dir = self.storage_dir / "metadata"
        ensure_directory(str(metadata_dir))

        metadata_file = metadata_dir / f"{metadata['id']}.json"
        write_json_file(str(metadata_file), metadata)

0 commit comments

Comments
 (0)