2424from executorch .examples .models .llama2 .source_transformation .sdpa import (
2525 replace_sdpa_with_custom_op ,
2626)
27+ from executorch .examples .models .llava .model import LlavaModel
2728from executorch .exir import EdgeCompileConfig
2829from executorch .exir .program ._program import _to_edge_transform_and_lower
2930
3031from executorch .extension .llm .export .builder import DType , LLMEdgeManager
31- from model import LlavaModel
3232from torch .ao .quantization .quantizer .xnnpack_quantizer import (
3333 get_symmetric_quantization_config ,
3434 XNNPACKQuantizer ,
@@ -85,7 +85,7 @@ def forward(self, input_pos, embeddings):
8585 ["-X" , "-qmode" , "8da4w" , "--group_size" , "128" , "--embedding-quantize" , "4,32" ]
8686 )
8787 quant_transform = get_quant_weight_transform (args , dtype_override , False )
88- pt2e_quant_params , quantizers , quant_dtype = get_quantizer_and_quant_params (args )
88+ _ , quantizers , _ = get_quantizer_and_quant_params (args )
8989 source_transforms = []
9090 if llava .use_sdpa_with_kv_cache_op :
9191 source_transforms .append (replace_sdpa_with_custom_op )
@@ -149,15 +149,7 @@ def forward(self, images):
149149
150150
151151def export_token_embedding (llava , prompt ):
152- embed = torch .nn .Embedding (
153- llava .model_ .config .vocab_size ,
154- llava .model_ .config .hidden_size ,
155- llava .model_ .config .pad_token_id ,
156- )
157- embed .load_state_dict (
158- llava .model_ .get_model ().embed_tokens .state_dict (), strict = True , assign = True
159- )
160- embed = embed .to (torch .float32 )
152+ embed = llava .embed_tokens
161153 token_dim_1 = Dim ("token_dim_1" , min = 2 , max = 3518 )
162154 dynamic_shapes = [{1 : token_dim_1 }]
163155 with torch .no_grad ():
@@ -167,24 +159,7 @@ def export_token_embedding(llava, prompt):
167159 return token_embedding_ep
168160
169161
170- def main ():
171- parser = ArgumentParser ()
172- parser .add_argument (
173- "--use-sdpa-with-kv-cache" ,
174- default = True ,
175- action = BooleanOptionalAction ,
176- help = "Use sdpa_with_kv_cache custom op in LLava text model." ,
177- )
178- parser .add_argument (
179- "--pte-name" ,
180- default = "llava_combined_xnnpack.pte" ,
181- help = "Name of the exported ExecuTorch program." ,
182- )
183- args = parser .parse_args ()
184- logging .info (
185- f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: { args .use_sdpa_with_kv_cache } "
186- )
187- llava_model = LlavaModel (use_sdpa_with_kv_cache_op = args .use_sdpa_with_kv_cache )
162+ def export_all (llava_model : LlavaModel ):
188163 llava = llava_model .get_eager_model ()
189164
190165 (
@@ -226,6 +201,29 @@ def main():
226201 )
227202
228203 executorch_program = lowered_and_edge .to_executorch ()
204+ return executorch_program
205+
206+
def main():
    """CLI entry point: export the Llava model to an ExecuTorch ``.pte`` file.

    Parses command-line flags, builds a ``LlavaModel``, lowers it via
    ``export_all``, and serializes the resulting ExecuTorch program to disk.

    Flags:
        --use-sdpa-with-kv-cache: bool (default True) — use the
            sdpa_with_kv_cache custom op in the Llava text model.
        --pte-name: str — output filename for the exported program.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--use-sdpa-with-kv-cache",
        default=True,
        action=BooleanOptionalAction,
        help="Use sdpa_with_kv_cache custom op in LLava text model.",
    )
    arg_parser.add_argument(
        "--pte-name",
        default="llava_combined_xnnpack.pte",
        help="Name of the exported ExecuTorch program.",
    )
    args = arg_parser.parse_args()

    logging.info(
        f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}"
    )

    # Build the eager model wrapper, then run the full export/lowering pipeline.
    llava_model = LlavaModel(use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache)
    executorch_program = export_all(llava_model)

    # Serialize the lowered program; .pte is the ExecuTorch binary format.
    with open(args.pte_name, "wb") as output_file:
        executorch_program.write_to_file(output_file)
0 commit comments