Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions exp_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,18 @@
"mlm_masking_probability": 0.15,
},
"contrastive_local": {
"alpha": 0.5,
"alpha": 0.4,
"initial_temperature_coef": 1.0725, # Matches initial value in clip.
"local_contrastive_loss": True,
"mlm_masking_probability": 0.15,
"contrastive_masking_probability": 0.3,
"contrastive_masking_probability": 0.2,
},
"contrastive_global": {
"alpha": 0.5,
"alpha": 0.4,
"initial_temperature_coef": 1.0725, # Matches initial value in clip.
"local_contrastive_loss": False,
"mlm_masking_probability": 0.15,
"contrastive_masking_probability": 0.3,
"contrastive_masking_probability": 0.2,
},
}

Expand Down
15 changes: 8 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
--find-links https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
numpy==1.22.4
torch==2.0.0+cu117
numpy==1.24.2
matplotlib==3.4.3
sklearn==1.1.2
pandas==1.4.3
scikit-learn==1.2.2
pandas==1.5.3
datasets==2.10.1
sentencepiece==0.1.97
transformers==4.21.1
accelerate==0.16.0
transformers==4.27.2
accelerate==0.17.1
beir==1.0.1
mteb==1.0.1
seaborn
wandb
huggingface-cli
haven-ai @ git+https://github.com/haven-ai/haven-ai@00fe4e3a10bfe09fef361836b8fcfcffcecd3451
haven-ai
10 changes: 2 additions & 8 deletions src/datasets_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,7 @@ def get_dataset(
"""
try:
base_dataset = load_dataset(
dataset_name,
use_auth_token=True,
cache_dir=path_to_cache,
split=split
dataset_name, use_auth_token=True, cache_dir=path_to_cache, split=split
)
except FileNotFoundError:
try:
Expand Down Expand Up @@ -168,13 +165,10 @@ def get_dataset(
split_preproc_key
]

base_dataset = base_dataset.map(
pre_proc_fn(maximum_raw_length), num_proc=96
)
base_dataset = base_dataset.map(pre_proc_fn(maximum_raw_length), num_proc=96)

base_dataset = base_dataset.shuffle(seed=42)


if "train" in split_preproc_key:
return RandomlyPairedDataset(base_dataset)
else:
Expand Down
36 changes: 36 additions & 0 deletions src/distributed_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
import torch.distributed as dist


class AllGather(torch.autograd.Function):
"""
all_gather with gradient back-propagation
Adapted from https://github.com/Lightning-AI/lightning-bolts/blob/5577453a6d7072724d9ae24184daf8f45d4baff7/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20-L40
"""

@staticmethod
def forward(ctx, tensor):
ctx.batch_size = tensor.shape[0]

gathered_tensor = [
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
]

torch.distributed.all_gather(gathered_tensor, tensor)
gathered_tensor = torch.cat(gathered_tensor, 0)

return gathered_tensor

@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
torch.distributed.all_reduce(
grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False
)

idx_from = torch.distributed.get_rank() * ctx.batch_size
idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
return grad_input[idx_from:idx_to]


all_gather = AllGather.apply
8 changes: 5 additions & 3 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Union
import torch
from accelerate.utils.operations import _gpu_gather
from src.distributed_utils import all_gather


class TempCoef(torch.nn.Module):
Expand Down Expand Up @@ -64,7 +64,8 @@ def gather_embeddings(
1,
)

embedding_dist = _gpu_gather(embedding)
# Gather embeddings across devices
embedding_dist = all_gather(embedding)

embedding_1_dist = embedding_dist[:, 0, :]
embedding_2_dist = embedding_dist[:, 1, :]
Expand Down Expand Up @@ -96,9 +97,10 @@ def clip_contrastive_loss(
# Gathers embeddings across devices.
emb_1_dist, emb_2_dist = gather_embeddings(emb_1, emb_2)

# Compute cosine similarity matrix
# Compute similarity matrix
similarities = emb_1_dist @ emb_2_dist.T

# Multiply similarity matrix by temperature
similarities = temperature_coef(similarities)

# Matching representations of positive pairs assumed to be located at the main
Expand Down