Skip to content
Open
Prev Previous commit
Next Next commit
load llava
  • Loading branch information
pingbowen23 committed Mar 31, 2024
commit 53a0fc7a00cd5f4a159952c6005e1589b32b15a0
8 changes: 6 additions & 2 deletions bitdelta/diff2.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def copy_nonzero_values(A, B):
A[mask] = B[mask]
return A

def compress_diff(base_model, finetuned_model, finetuned_compressed_model,save_dir,args):
def compress_diff(base_model, finetuned_model, save_dir,args):
def compress_submodule(name, subname, module, submodule):
target_device = submodule.weight.device

Expand All @@ -121,7 +121,11 @@ def compress_submodule(name, subname, module, submodule):
setattr(module, subname, compressed)

    # TODO: choose the compression ratio based on thresh
for name, module in finetuned_compressed_model.named_modules():
for name, module in finetuned_model.named_modules():

if "vision" in name:
continue

if "self_attn" in name or "mlp" in name:
for subname, submodule in module.named_children():

Expand Down
13 changes: 8 additions & 5 deletions bitdelta/train2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from bitdelta.diff2 import compress_diff, save_diff, save_full_model
from bitdelta.misc import find_corr_stddev

from bitdelta.utils import get_model, parse_args, get_tokenizer
from bitdelta.utils import get_model, parse_args, get_tokenizer,load_llava
from tqdm import tqdm
from bitdelta.data import get_dataset, get_dataloader

Expand All @@ -21,11 +21,14 @@

with torch.no_grad():
base_model = get_model(args.base_model, args.base_model_device, args.base_model_memory_map)
finetuned_model = get_model(args.finetuned_model, args.finetuned_model_device, args.finetuned_model_memory_map)
if "llava" in args.finetuned_model.lower():
finetuned_model = load_llava(args.finetuned_model,device="cuda" if torch.cuda.is_available() else "cpu")
else:
finetuned_model = get_model(args.finetuned_model, args.finetuned_model_device, args.finetuned_model_memory_map)

finetuned_compressed_model = get_model(args.finetuned_model, args.finetuned_compressed_model_device, args.finetuned_compressed_model_memory_map)

import pdb;pdb.set_trace()
print(f"compressing diff...")
compress_diff(base_model, finetuned_model, finetuned_compressed_model,args.save_dir,args)
compress_diff(base_model, finetuned_model, args.save_dir,args)

tokenizer.save_pretrained(args.save_dir)
tokenizer.save_pretrained(args.save_dir)
40 changes: 39 additions & 1 deletion bitdelta/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,46 @@
import argparse
import transformers
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM,AutoTokenizer
from accelerate import infer_auto_device_map, init_empty_weights
import os
from llava.model import *
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

def load_llava(path, device):
    """Load a LLaVA checkpoint and return the (bfloat16) language model.

    Loads the tokenizer and ``LlavaLlamaForCausalLM`` weights from ``path``,
    registers the multimodal special tokens (so token embeddings match the
    checkpoint), and ensures the vision tower weights are loaded.

    Args:
        path: Hugging Face model path or local checkpoint directory.
        device: Target device (e.g. ``"cuda"``/``"cpu"``), or ``"auto"`` to
            let the vision tower's own device map stand.

    Returns:
        The loaded ``LlavaLlamaForCausalLM`` model on ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
    model = LlavaLlamaForCausalLM.from_pretrained(
        path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    ).to(device)

    if 'llava' in path.lower():
        # Add the image special tokens exactly as the checkpoint's config
        # dictates, then resize embeddings so the new ids are addressable.
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        # Vision tower weights are loaded lazily; force-load them here so the
        # returned model is fully materialized.
        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model(device_map=device)
        if device != 'auto':
            # NOTE(review): the vision tower is cast to float16 while the LM
            # was loaded in bfloat16 — confirm this dtype mismatch is intended.
            vision_tower.to(device=device, dtype=torch.float16)

    # The tokenizer, image processor, and context length were previously
    # computed but discarded; only the model is part of this function's
    # contract (callers in train2.py use the returned model alone).
    return model


def parse_args():
parser = argparse.ArgumentParser(description="BitDelta")
Expand Down
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
MODEL_SAVE_DIR=./../save/uncalibrated_llava
MODEL_SAVE_DIR=./../save/llama_7b_chat_attn_mlp_outlier_0.2_0.1/

mkdir -p $MODEL_SAVE_DIR

Expand Down