Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions launcher.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ NSTEPS=100_000
torchrun --nproc_per_node $NGPUS \
trainval.py \
-e mlm \
--wandb-entity-name bigcode \
--wandb-project-name tf_encoder \
--wandb-run-name mlm \
--wandb-log-gradients false \
--steps $NSTEPS \
-sb $PATH_TO_LOG \
--train_data_name $TRAIN_DATA_NAME \
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ pandas==1.5.3
datasets==2.10.1
sentencepiece==0.1.97
transformers==4.27.2
accelerate==0.17.1
beir==1.0.1
mteb==1.0.1
seaborn
Expand Down
32 changes: 22 additions & 10 deletions src/datasets_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import torch
from torch.utils.data import Dataset
import datasets

# Workaround toolkit misreporting available disk space.
datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True
from datasets import load_dataset, load_from_disk
from datasets.builder import DatasetBuildError
from transformers import AutoTokenizer
from src.preprocessing_utils import (
perturb_tokens,
Expand Down Expand Up @@ -129,23 +133,31 @@ def get_dataset(
"""
try:
base_dataset = load_dataset(
dataset_name, use_auth_token=True, cache_dir=path_to_cache, split=split
dataset_name,
use_auth_token=True,
cache_dir=path_to_cache,
split=split,
)
except DatasetBuildError:
# Try to specify data files. Specific for The Stack.
base_dataset = load_dataset(
dataset_name,
use_auth_token=True,
cache_dir=path_to_cache,
data_files="sample.parquet",
split=split,
)
except FileNotFoundError:
try:
base_dataset = load_dataset(
dataset_name,
use_auth_token=True,
cache_dir=path_to_cache,
)[split]
except FileNotFoundError:
base_dataset = load_from_disk(path_to_cache)
# Try to load from disk if above failed.
base_dataset = load_from_disk(path_to_cache)

if force_preprocess:
base_dataset.cleanup_cache_files()

base_dataset = base_dataset.shuffle(seed=42)

if maximum_row_cout is not None:
base_dataset = base_dataset.shuffle(seed=42).select(
base_dataset = base_dataset.select(
range(min(len(base_dataset), maximum_row_cout))
)

Expand Down
23 changes: 22 additions & 1 deletion src/hf_trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import wandb
import torch
from torch.utils.data.dataset import Dataset
import transformers
Expand All @@ -16,6 +17,9 @@
retrieval_eval,
)
from src.datasets_loader import Collator
from src.logging_callback import LoggingCallback

from functools import partial


def compute_metrics(eval_pred: PredictionOutput) -> Dict[str, float]:
Expand Down Expand Up @@ -221,6 +225,10 @@ def get_trainer(
log_every: int = 100,
local_rank: int = 0,
deepspeed_cfg_path: str = None,
wandb_entity_name: str = None,
wandb_project_name: str = None,
wandb_run_name: str = None,
wandb_log_grads: bool = False,
) -> CustomTrainer:
"""Intanstiates Trainer object.

Expand All @@ -234,6 +242,10 @@ def get_trainer(
log_every (int): Logging interval.
local_rank (int): Device id for distributed training.
deepspeed_cfg_path (str, Optional): Optional path to deepspeed config.
wandb_entity_name (str, optional): Wandb entity. Defaults to None.
wandb_project_name (str, optional): Project name for wandb. Defaults to None.
wandb-run-name (str, optional): Run id name for wandb. Defaults to None.
wandb_log_grads (bool, optional): Whether to write grads on wandb logs. Defaults to False.

Returns:
CustomTrainer: Trainer object.
Expand Down Expand Up @@ -265,18 +277,27 @@ def get_trainer(
save_strategy="steps",
save_steps=log_every,
evaluation_strategy="steps",
report_to="wandb",
# report_to="wandb",
)

encoder = get_encoder(exp_dict=exp_dict)

wandb.init(
name=wandb_run_name,
entity=wandb_entity_name,
project=wandb_project_name,
)

trainer = CustomTrainer(
model=encoder,
args=training_args,
train_dataset=train_dataset,
eval_dataset=valid_dataset,
compute_metrics=compute_metrics,
data_collator=collate_fn,
callbacks=[
LoggingCallback(log_grads=wandb_log_grads),
],
)

return trainer
21 changes: 21 additions & 0 deletions src/logging_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from transformers.integrations import WandbCallback


class LoggingCallback(WandbCallback):
"""
Overrigding WandbCallback to optionally turn off gradient logging.
"""

def __init__(self, log_grads: bool):

super().__init__()

self.log_grads = log_grads

def setup(self, args, state, model, **kwargs):

super().setup(args, state, model, **kwargs)
_watch_model = "all" if self.log_grads else "parameters"
self._wandb.watch(
model, log=_watch_model, log_freq=max(100, args.logging_steps)
)
36 changes: 35 additions & 1 deletion src/training_args.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,26 @@
import argparse


def parse_bool_flag(s: str) -> bool:
"""Parse boolean arguments from the command line.

Args:
s (str): Input arg string.

Returns:
bool: _description_
"""
_FALSY_STRINGS = {"off", "false", "0"}
_TRUTHY_STRINGS = {"on", "true", "1"}
if s.lower() in _FALSY_STRINGS:
return False
elif s.lower() in _TRUTHY_STRINGS:
return True
else:
raise argparse.ArgumentTypeError("Invalid value for a boolean flag")


def parse_args():
# Specify arguments regarding save directory and job scheduler
parser = argparse.ArgumentParser()
parser.add_argument(
"-e",
Expand Down Expand Up @@ -45,6 +63,22 @@ def parse_args():
type=int,
help="Number of iterations to wait before logging training scores.",
)
parser.add_argument(
"--wandb-entity-name",
type=str,
default="bigcode",
help="Name of wandb entity for reporting.",
)
parser.add_argument(
"--wandb-project-name", type=str, default=None, help="Name of wandb project."
)
parser.add_argument("--wandb-run-name", type=str, default=None, help="Name of run.")
parser.add_argument(
"--wandb-log-gradients",
type=parse_bool_flag,
default="false",
help="Whether to write gradients to wandb logs.",
)
parser.add_argument(
"--dist_url",
default="env://",
Expand Down
4 changes: 4 additions & 0 deletions trainval_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def train(exp_dict, savedir, args):
valid_dataset=gfg_test_data,
collate_fn=collate_fn,
log_every=args.log_every,
wandb_entity_name=args.wandb_entity_name,
wandb_project_name=args.wandb_project_name,
wandb_run_name=args.wandb_run_name,
wandb_log_grads=args.wandb_log_gradients,
local_rank=args.local_rank,
deepspeed_cfg_path=args.deepspeed,
)
Expand Down