-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathsetup.py
More file actions
40 lines (32 loc) · 1.18 KB
/
setup.py
File metadata and controls
40 lines (32 loc) · 1.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import logging
import os
from datasets import load_dataset
from typing import Iterator
from commit0.harness.utils import (
clone_repo,
)
from commit0.harness.constants import BASE_BRANCH, RepoInstance, SPLIT
# Configure root logging once at import time: INFO and above, each record
# prefixed with timestamp, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-scoped logger; it is also handed to clone_repo by main().
logger = logging.getLogger(__name__)
def main(
    dataset_name: str,
    dataset_split: str,
    repo_split: str,
    base_dir: str,
) -> None:
    """Clone every selected repository of a dataset and check out a fresh base branch.

    For each example of ``load_dataset(dataset_name, split=dataset_split)``,
    the GitHub repository ``example["repo"]`` is cloned with ``clone_repo``
    into ``<base_dir>/<repo name>`` (as an absolute path). The last
    ``/``-separated component of ``dataset_name`` is passed as the branch
    argument. Then ``BASE_BRANCH`` is deleted if it already exists and
    recreated from the current HEAD with ``git checkout -b``.

    Args:
        dataset_name: Dataset identifier given to ``load_dataset``. Its last
            ``/`` component is also used as the branch passed to ``clone_repo``.
        dataset_split: Split name given to ``load_dataset``.
        repo_split: ``"all"`` to process every repository. Otherwise, a key of
            ``SPLIT`` whose value lists the repository names to process.
        base_dir: Directory under which each repository is cloned.

    Raises:
        KeyError: If ``repo_split`` is not ``"all"`` and not a key of ``SPLIT``.
            This is raised before the dataset is loaded.
    """
    # Resolve the repository filter before touching the dataset. An unknown
    # split name then fails immediately rather than after load_dataset has run.
    restrict_repos = repo_split != "all"
    allowed_repos = SPLIT[repo_split] if restrict_repos else ()
    # The branch depends only on the dataset name, so compute it once.
    branch = dataset_name.split("/")[-1]

    dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split)  # type: ignore
    for example in dataset:
        repo_name = example["repo"].split("/")[-1]
        if restrict_repos and repo_name not in allowed_repos:
            continue
        clone_url = f"https://github.com/{example['repo']}.git"
        clone_dir = os.path.abspath(os.path.join(base_dir, repo_name))
        repo = clone_repo(clone_url, clone_dir, branch, logger)
        # Remove an existing base branch so it is recreated from the current HEAD.
        # NOTE(review): `git branch -d` refuses to delete a branch that is not
        # merged into HEAD or that is currently checked out. A re-run over a
        # clone with such a BASE_BRANCH will therefore raise here. Confirm that
        # this safety stop (rather than `-D`) is intended.
        if BASE_BRANCH in repo.branches:
            repo.git.branch("-d", BASE_BRANCH)
        repo.git.checkout("-b", BASE_BRANCH)
        logger.info("Checked out the base branch: %s", BASE_BRANCH)
# Deliberately empty: `from <this module> import *` exports nothing.
# `main` remains importable by explicit name.
__all__: "list[str]" = []