Skip to content
Open
Prev Previous commit
Next Next commit
finish fp16 + 1bit, check diff2.py train2.py
  • Loading branch information
pingbowen23 committed Mar 12, 2024
commit 86d443d17d33a221a413aade0e5f1cafa0d77e0b
62 changes: 37 additions & 25 deletions bitdelta/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class BinaryDiff(nn.Module):
def __init__(self, base, finetune):
super().__init__()
diff = finetune - base
# diff = decomposition(diff, 2048)
diff = decomposition(diff, st=64, ed=1024)
quantile = diff.float().abs().mean()

mask = torch.ones_like(diff)
Expand Down Expand Up @@ -66,17 +66,28 @@ def compress_submodule(name, subname, module, submodule):
setattr(module, subname, compressed)

# TODO: this can be parallelized
# flag = False
for name, module in finetuned_compressed_model.named_modules():
if "mlp" in name or "self_attn" in name:

if Pass(layers,name) == True:
continue

# if flag == True:
# break

if "self_attn" in name:
for subname, submodule in module.named_children():
if "proj" in subname:
compress_submodule(name, subname, module, submodule)

def save_diff(finetuned_compressed_model, save_dir,layers=None):
elif "mlp" in name:
with torch.no_grad():
for subname, submodule in module.named_children():
if "proj" in subname:
base_weight = base_model.get_submodule(f"{name}.{subname}").weight.detach().to(submodule.weight.device)
finetuned_weight = finetuned_model.get_submodule(f"{name}.{subname}").weight.detach().to(submodule.weight.device)
delta = decomposition(finetuned_weight - base_weight,dim=int(128 * 1.45))
finetuned_compressed_model.get_submodule(f"{name}.{subname}").weight.copy_(base_weight + delta.to(torch.bfloat16))
# flag = True
# import pdb; pdb.set_trace()
# break

def save_diff(finetuned_compressed_model, save_dir,layers=None,ori_diff=None):
diff_dict = {}

for name, module in finetuned_compressed_model.named_modules():
Expand All @@ -92,9 +103,10 @@ def save_diff(finetuned_compressed_model, save_dir,layers=None):
torch.save(diff_dict, save_dir)

@torch.no_grad()
def load_diff(model, diff_dir):
def load_diff(model, diff_dir,ori_diff):
device = model.device
diff_dict = torch.load(diff_dir)
# ori_diff = torch.load(ori_diff)

for name, module in model.named_modules():
if name + ".mask" in diff_dict:
Expand All @@ -104,13 +116,15 @@ def load_diff(model, diff_dir):
# setattr(module, "mask", mask)
# setattr(module, "coeff", coeff)
weight = (unpack(mask)*2-1) * coeff

if "mlp" in name:
weight = decomposition(weight, 1024)
weight_fp16 = decomposition(ori_diff[name + ".weight"].to(torch.float32), dim=64).to(torch.bfloat16)
# import pdb; pdb.set_trace()

module.weight.add_(weight.T.to(module.weight.dtype))
module.weight.add_(weight_fp16.to(module.weight.dtype) + weight.T.to(module.weight.dtype))
elif name + ".weight" in diff_dict:
module.weight = nn.Parameter(diff_dict[name + ".weight"].to(device).to(module.weight.dtype))

# if "mlp" in name:
# import pdb; pdb.set_trace()

elif name + '.A' in diff_dict:
A = diff_dict[name + '.A'].to(device)
Expand All @@ -121,17 +135,18 @@ def load_diff(model, diff_dir):

model.config.vocab_size = model.lm_head.weight.size(0)

def decomposition(masked_input_tensor,dim):
# if "mlp" in name:
# dim = int(dim * 1.45)
def decomposition(masked_input_tensor,dim=None,st=None,ed=None):
U , S , V = torch.svd(masked_input_tensor.to(torch.float32))

if dim is not None:
U , S , V = U[:, :dim],S[:dim] ,V[:, :dim]

if st is not None and ed is not None:
U , S , V = U[:, st:ed],S[st:ed] ,V[:, st:ed]

U , S , V = torch.svd(masked_input_tensor)
# total_sum , partial_sum = torch.sum(S) , torch.sum(S[:128])
# import pdb; pdb.set_trace()
U , S , V = U[:, :dim],S[:dim] ,V[:, :dim]
return torch.mm(torch.mm(U, torch.diag(S)), V.t())

def save_full_model(base_model_name, finetuned_model_name, diff_dir, save_dir, device,layers=None):
def save_full_model(base_model_name, finetuned_model_name, diff_dir, save_dir, device,layers=None,ori_diff=None):
base_model = get_model(base_model_name, device)
tokenizer = get_tokenizer(finetuned_model_name)

Expand All @@ -150,17 +165,14 @@ def save_full_model(base_model_name, finetuned_model_name, diff_dir, save_dir, d
# # import pdb; pdb.set_trace()
# params[k] = decomposition(delta.to(torch.float32), dim).to(torch.bfloat16)

# import pdb; pdb.set_trace()
# dict(base_model.named_parameters())['model.layers.0.self_attn.o_proj.weight']

# with torch.no_grad():
# for param in params:
# base_model.get_submodule(param.replace('.weight',"")).weight.add_(params[param].detach().to(device))

# import pdb; pdb.set_trace()
load_diff(base_model, diff_dir)
load_diff(base_model, diff_dir,ori_diff=ori_diff)


base_model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

Expand Down
173 changes: 173 additions & 0 deletions bitdelta/diff2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import torch
import torch.nn as nn
import gc

from bitdelta.binary_gemm_kernel import pack, unpack, binary_bmm
from bitdelta.utils import get_model, get_tokenizer

class BinaryDiff(nn.Module):
    """1-bit compression of a weight tensor (here: a weight delta or SVD factor).

    Stores only the packed sign pattern of ``weight`` plus a single trainable
    scale ``coeff`` (initialised to the mean absolute value of the input), so a
    full-precision matrix is reduced to roughly 1 bit per element.
    """

    def __init__(self, weight):
        super().__init__()
        # Tensor to binarise; kept under the historical name ``diff``.
        diff = weight
        # Per-tensor scale: mean |value|, computed in float32 for stability.
        quantile = diff.float().abs().mean()

        # Sign mask: 1 where diff >= 0, 0 where diff < 0 (NaN entries keep 1,
        # since ``NaN < 0`` is False). Packed bitwise over the transposed
        # layout expected by binary_bmm.
        mask = torch.ones_like(diff)
        mask[diff < 0] = 0
        mask = pack(mask.bool().T)

        self.register_buffer("mask", mask)
        # self.register_buffer("base", base.T)
        self.register_parameter(
            "coeff",
            nn.Parameter(
                torch.tensor(
                    quantile,
                    dtype=torch.float32,
                    requires_grad=True,
                    device=weight.device,
                )
            ),
        )
        # del base, finetune, diff

    def forward(self, x):
        # print(x.shape, self.base.shape, self.coeff.shape, self.mask.shape)
        # [B, seq, in] @ [in, out] + [B, seq, in] @ [B, in/32, out]

        # NOTE(review): ``self.base`` is never registered in __init__ (the
        # register_buffer call is commented out), so calling forward() will
        # raise AttributeError as written — confirm whether forward() is ever
        # invoked, or restore the ``base`` buffer before using it.
        # TODO: This can be faster
        repeated_mask = self.mask.unsqueeze(0).repeat(x.size(0), 1, 1)
        return x @ self.base + self.coeff * binary_bmm(x, repeated_mask)

def Pass(layers=None, name=None):
    """Return True when *name* contains any of the substrings in *layers*.

    Used to skip ("pass over") layers during compression. A ``layers`` of
    None means nothing is skipped.
    """
    if layers is None:
        return False
    return any(layer in name for layer in layers)


def compress_diff(base_model, finetuned_model, finetuned_compressed_model,save_dir,layers=None):
    """Compress the base→finetuned weight deltas and save the mutated model.

    For self_attn ``*proj`` weights: the delta is SVD-truncated to 1024
    components, the tail columns (index 64 onward) of U and V are 1-bit
    quantised via BinaryDiff, and the reconstructed delta is written back
    into ``finetuned_model``. For mlp ``*proj`` weights: the delta is replaced
    by a rank-``int(128 * 1.45)`` SVD approximation kept in full precision.
    ``finetuned_model`` is modified in place and saved to ``save_dir``.
    ``layers`` is accepted but not used in this function body.
    """
    # NOTE(review): this inner helper is never called below, and it constructs
    # BinaryDiff(base=..., finetune=...) although this file's BinaryDiff only
    # accepts ``weight`` — re-enabling it as-is would raise TypeError.
    def compress_submodule(name, subname, module, submodule):
        target_device = submodule.weight.device

        base_weight = base_model.get_submodule(f"{name}.{subname}").weight.detach().to(target_device)
        finetuned_weight = finetuned_model.get_submodule(f"{name}.{subname}").weight.detach().to(target_device)

        compressed = BinaryDiff(
            base=base_weight,
            finetune=finetuned_weight,
        ).to(target_device)

        del submodule, base_weight
        setattr(module, subname, None)
        gc.collect()
        torch.cuda.empty_cache()
        setattr(module, subname, compressed)

    # TODO: this can be parallelized
    for name, module in finetuned_compressed_model.named_modules():

        if "self_attn" in name:
            for subname, submodule in module.named_children():
                if "proj" in subname:
                    # Fetch matching base / finetuned weights on the same device.
                    base_weight = base_model.get_submodule(f"{name}.{subname}").weight.detach().to(submodule.weight.device)
                    finetuned_weight = finetuned_model.get_submodule(f"{name}.{subname}").weight.detach().to(submodule.weight.device)
                    # compress_submodule(name, subname, module, submodule)
                    # Rank-1024 truncated SVD of the weight delta.
                    U,S,V = decomposition(finetuned_weight - base_weight,dim=1024)

                    # Keep the first 64 singular directions in full precision;
                    # 1-bit quantise the remaining columns of U and V.
                    compressed_U, compressed_V = BinaryDiff(weight=U[:,64:]).to(finetuned_weight.device), BinaryDiff(weight=V[:,64:]).to(finetuned_weight.device)
                    U_mask, U_coeff, V_mask, V_coeff = compressed_U.mask, compressed_U.coeff, compressed_V.mask, compressed_V.coeff
                    # Unpack sign bits {0,1} -> {-1,+1} and apply the learned scale.
                    weight_U , weight_V = (unpack(U_mask)*2-1) * U_coeff, (unpack(V_mask)*2-1) * V_coeff
                    # import pdb; pdb.set_trace()
                    U[:,64:] , V[:,64:] = weight_U.T, weight_V.T # not sure whether this has a bug
                    # Reconstruct the (approximate) delta and fold it back into
                    # the finetuned model's weight in place.
                    delta = U @ torch.diag(S) @ V.t()
                    with torch.no_grad():
                        finetuned_model.get_submodule(f"{name}.{subname}").weight.copy_(base_weight + delta.to(torch.bfloat16))


        elif "mlp" in name:
            with torch.no_grad():
                for subname, submodule in module.named_children():
                    if "proj" in subname:
                        base_weight = base_model.get_submodule(f"{name}.{subname}").weight.detach().to(submodule.weight.device)
                        finetuned_weight = finetuned_model.get_submodule(f"{name}.{subname}").weight.detach().to(submodule.weight.device)
                        # Rank budget int(128 * 1.45) = 185 — presumably tuned
                        # empirically; TODO confirm.
                        U,S,V = decomposition(finetuned_weight - base_weight,dim=int(128 * 1.45))
                        delta = torch.mm(torch.mm(U, torch.diag(S)), V.t())
                        finetuned_model.get_submodule(f"{name}.{subname}").weight.copy_(base_weight + delta.to(torch.bfloat16))


    finetuned_model.save_pretrained(save_dir)

def save_diff(finetuned_compressed_model, save_dir, layers=None, ori_diff=None):
    """Collect the compressed delta state and ``torch.save`` it to *save_dir*.

    BinaryDiff modules contribute their packed sign mask and scale; any
    still-trainable parameters are stored verbatim. ``layers`` and
    ``ori_diff`` are accepted for call-site compatibility but unused here.
    """
    state = {}

    # Packed 1-bit masks and their learned coefficients.
    for mod_name, mod in finetuned_compressed_model.named_modules():
        if isinstance(mod, BinaryDiff):
            state[mod_name + ".mask"] = mod.mask.cpu()
            state[mod_name + ".coeff"] = mod.coeff.cpu()

    # Remaining trainable parameters, moved to CPU for serialisation.
    for param_name, param in finetuned_compressed_model.named_parameters():
        if param.requires_grad:
            state[param_name] = param.cpu()

    torch.save(state, save_dir)

@torch.no_grad()
def load_diff(model, diff_dir, ori_diff):
    """Apply a saved compressed diff to ``model`` in place.

    Args:
        model: target model whose weights are updated in place.
        diff_dir: path to the ``diff_dict`` file produced by ``save_diff``.
        ori_diff: dict of original (uncompressed) weight deltas keyed by
            parameter name; a rank-64 SVD approximation of each entry is
            added on top of the 1-bit reconstruction.
    """
    device = model.device
    diff_dict = torch.load(diff_dir)

    for name, module in model.named_modules():
        if name + ".mask" in diff_dict:
            coeff = diff_dict[name + ".coeff"].to(device)
            mask = diff_dict[name + ".mask"].to(device)

            # 1-bit reconstruction: unpack sign bits {0,1} -> {-1,+1} and
            # scale by the learned coefficient.
            weight = (unpack(mask) * 2 - 1) * coeff

            # Bug fix: decomposition() in this file returns the SVD factors
            # (U, S, V), not a reconstructed matrix, so the previous
            # ``decomposition(...).to(torch.bfloat16)`` raised AttributeError
            # on the returned tuple. Rebuild the rank-64 approximation here.
            U, S, V = decomposition(ori_diff[name + ".weight"].to(torch.float32), dim=64)
            weight_fp16 = torch.mm(torch.mm(U, torch.diag(S)), V.t()).to(torch.bfloat16)

            module.weight.add_(weight_fp16.to(module.weight.dtype) + weight.T.to(module.weight.dtype))
        elif name + ".weight" in diff_dict:
            # Full replacement weights (e.g. embeddings / lm_head).
            module.weight = nn.Parameter(diff_dict[name + ".weight"].to(device).to(module.weight.dtype))
        elif name + '.A' in diff_dict:
            # Low-rank (A @ B) correction applied additively.
            A = diff_dict[name + '.A'].to(device)
            B = diff_dict[name + '.B'].to(device)

            module.weight.add_((A @ B).T.to(module.weight.dtype))

    # Keep config in sync in case lm_head was resized/replaced above.
    model.config.vocab_size = model.lm_head.weight.size(0)

def decomposition(masked_input_tensor, dim=None, st=None, ed=None, name=None):
    """Truncated SVD of *masked_input_tensor*, computed in float32.

    Returns the ``torch.svd`` factors ``(U, S, V)``, optionally truncated to
    the leading ``dim`` singular components and/or sliced to the component
    range ``[st, ed)``. ``name`` is unused (kept for call-site compatibility).
    """
    u, s, v = torch.svd(masked_input_tensor.to(torch.float32))

    if dim is not None:
        # Keep only the top-``dim`` singular directions.
        u = u[:, :dim]
        s = s[:dim]
        v = v[:, :dim]

    if st is not None and ed is not None:
        # Keep only the singular directions in [st, ed).
        u = u[:, st:ed]
        s = s[st:ed]
        v = v[:, st:ed]

    return u, s, v

def save_full_model(base_model_name, finetuned_model_name, diff_dir, save_dir, device, layers=None, ori_diff=None):
    """Reconstruct the full fine-tuned model and save it to *save_dir*.

    Loads the base model, applies the compressed diff from ``diff_dir``
    (plus the ``ori_diff`` low-rank correction) in place, and saves the
    resulting model together with the fine-tuned model's tokenizer.
    ``layers`` is accepted for call-site compatibility but unused here.
    """
    base_model = get_model(base_model_name, device)
    tokenizer = get_tokenizer(finetuned_model_name)

    # Fix: the fine-tuned model itself was previously loaded here as well but
    # never used, doubling peak memory for nothing — only its tokenizer and
    # the saved diff are needed to rebuild the model.
    load_diff(base_model, diff_dir, ori_diff=ori_diff)

    base_model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)

    # Release the (potentially large) model explicitly.
    del base_model

11 changes: 10 additions & 1 deletion bitdelta/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
base_model = get_model(args.base_model, args.base_model_device, args.base_model_memory_map)
finetuned_model = get_model(args.finetuned_model, args.finetuned_model_device, args.finetuned_model_memory_map)

def original_diff(base_model, finetuned_model):
    """Return the raw finetuned-minus-base weight deltas on CPU.

    Only parameters belonging to mlp or self_attn submodules are included;
    keys are the fine-tuned model's parameter names.
    """
    return {
        k: v.detach().cpu() - base_model.get_submodule(k.replace('.weight', "")).weight.detach().cpu()
        for k, v in finetuned_model.named_parameters()
        if "mlp" in k or "self_attn" in k
    }

# get corr/stddev stats
if args.debug:
print(f"finding corr/stddev stats...")
Expand Down Expand Up @@ -94,6 +101,8 @@
with open(os.path.join(args.save_dir, f"train_loss_{args.num_groups}.json"), "w") as f:
json.dump(train_loss_list, f)

ori_diff = original_diff(base_model, finetuned_model)

# # save trained delta
save_diff(finetuned_compressed_model, os.path.join(args.save_dir, "diff.pt"),layers=args.layers)

Expand All @@ -102,6 +111,6 @@

if args.save_full_model:
print("saving uncalibrated model")
save_full_model(args.base_model, args.finetuned_model, os.path.join(args.save_dir, "diff_untrained.pt"), os.path.join(args.save_dir, f"uncalibrated_model"), device="cpu",layers=args.layers)
save_full_model(args.base_model, args.finetuned_model, os.path.join(args.save_dir, "diff_untrained.pt"), os.path.join(args.save_dir, f"uncalibrated_model"), device="cpu",layers=args.layers,ori_diff=ori_diff)
# print("saving calibrated model")
# save_full_model(args.base_model, args.finetuned_model, os.path.join(args.save_dir, "diff.pt"), os.path.join(args.save_dir, "calibrated_model"), device="cpu")
14 changes: 4 additions & 10 deletions bitdelta/train2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

import torch.nn.functional as F
from bitdelta.diff import compress_diff, save_diff, save_full_model
from bitdelta.diff2 import compress_diff, save_diff, save_full_model
from bitdelta.misc import find_corr_stddev

from bitdelta.utils import get_model, parse_args, get_tokenizer
Expand All @@ -17,7 +17,7 @@
# create save_dir if it doesn't exist
os.makedirs(args.save_dir, exist_ok=True)

tokenizer = get_tokenizer(args.base_model)
tokenizer = get_tokenizer(args.finetuned_model)

with torch.no_grad():
base_model = get_model(args.base_model, args.base_model_device, args.base_model_memory_map)
Expand All @@ -26,12 +26,6 @@
finetuned_compressed_model = get_model(args.finetuned_model, args.finetuned_compressed_model_device, args.finetuned_compressed_model_memory_map)

print(f"compressing diff...")
compress_diff(base_model, finetuned_model, finetuned_compressed_model)
compress_diff(base_model, finetuned_model, finetuned_compressed_model,args.save_dir)

# save untrained delta
save_diff(finetuned_compressed_model, os.path.join(args.save_dir, "diff_untrained.pt"))


if args.save_full_model:
print("saving uncalibrated model")
save_full_model(args.base_model, args.finetuned_model, os.path.join(args.save_dir, "diff_untrained.pt"), os.path.join(args.save_dir, "uncalibrated_model"), device="cpu")
tokenizer.save_pretrained(args.save_dir)
4 changes: 2 additions & 2 deletions run.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
MODEL_SAVE_DIR=save/
MODEL_SAVE_DIR=save/uncalibrated_model_0

mkdir -p $MODEL_SAVE_DIR

CUDA_VISIBLE_DEVICES=6,7 python \
bitdelta/train.py \
bitdelta/train2.py \
--base_model /data/public/opensource_models/meta-llama/Llama-2-7b-hf/ \
--finetuned_model /data/public/opensource_models/WizardLM/WizardMath-7B-V1.0/ \
--save_dir $MODEL_SAVE_DIR \
Expand Down
Loading