Skip to content

Commit 9455734

Browse files
committed
Update notebook with custom adapter
1 parent b2a2db8 commit 9455734

1 file changed

Lines changed: 165 additions & 3 deletions

File tree

docs/tutorials/nn/neural-network-adapters.ipynb

Lines changed: 165 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,175 @@
302302
"!python train.py train_lora.yaml --number_of_epochs=1 --batch_size=2 --test_scorer \"!ref <valid_scorer>\" --enable_add_reverb=False --enable_add_noise=False #To speed up"
303303
]
304304
},
305+
{
306+
"cell_type": "markdown",
307+
"id": "f960be7a-6edb-47c1-bd8d-c7c0d486c81d",
308+
"metadata": {},
309+
"source": [
310+
"## Custom adapter\n",
311+
"\n",
312+
"We designed this so that you could replace the SpeechBrain adapter with a `peft` adapter:\n",
313+
"\n",
314+
"```diff\n",
315+
"model: !new:speechbrain.nnet.adapters.AdaptedModel\n",
316+
" model_to_adapt: !ref <model_pretrained>\n",
317+
"- adapter_class: !name:speechbrain.nnet.adapters.LoRA\n",
318+
"+ adapter_class: !name:peft.tuners.lora.layer.Linear\n",
319+
" adapter_kwargs:\n",
320+
"- rank: 16\n",
321+
"+ r: 16\n",
322+
"+ adapter_name: lora\n",
323+
"```\n",
324+
"\n",
325+
"But this trains exactly the same thing as before, so no need for us to go through the whole thing. Perhaps more interesting is designing a custom adapter:"
326+
]
327+
},
305328
{
306329
"cell_type": "code",
307-
"execution_count": null,
308-
"id": "f3ed4d04-c6db-4cb1-a086-d558564fe402",
330+
"execution_count": 12,
331+
"id": "f9682f70-489a-4a1d-b8c6-1c73d98a824d",
332+
"metadata": {},
333+
"outputs": [
334+
{
335+
"name": "stdout",
336+
"output_type": "stream",
337+
"text": [
338+
"Writing conv_lora.py\n"
339+
]
340+
}
341+
],
342+
"source": [
343+
# conv_lora.py — written from the notebook via the %%file cell magic.

import torch


class Conv2dLoRA(torch.nn.Module):
    """Convolutional LoRA-style adapter.

    Wraps a pretrained module (frozen) and adds a trainable low-rank
    convolutional path: a strided down-convolution, an upsample back to the
    input resolution, and a 1-channel projection convolution. The adapter
    output is scaled and added to the pretrained module's output.

    Arguments
    ---------
    target_module : torch.nn.Module
        The pretrained module to adapt; its parameters are frozen. Must
        expose a ``weight`` attribute (used to pick the device).
    kernel_size : int
        Kernel size for both adapter convolutions.
    stride : int
        Stride of the down-convolution; the upsample uses the same factor
        so input/output spatial sizes match (for even spatial dims).
    channels : int
        Number of intermediate adapter channels (the "rank" analogue).
    scaling : float
        Multiplier applied to the adapter branch before the residual add.
    """

    def __init__(self, target_module, kernel_size=3, stride=2, channels=16, scaling=1.0):
        super().__init__()

        # Disable gradient for pretrained module
        self.pretrained_module = target_module
        for param in self.pretrained_module.parameters():
            param.requires_grad = False
        device = target_module.weight.device

        # padding="same" is not supported by PyTorch for strided convolutions,
        # so use an explicit symmetric padding instead.
        padding = kernel_size // 2
        self.adapter_down_conv = torch.nn.Conv2d(
            in_channels=1,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=False,
            device=device,
        )
        # Upsample by the same factor the down-conv strided by, restoring
        # the spatial resolution before the residual addition.
        self.adapter_up_scale = torch.nn.Upsample(scale_factor=stride)
        self.adapter_up_conv = torch.nn.Conv2d(
            in_channels=channels,
            out_channels=1,
            kernel_size=kernel_size,
            padding="same",
            bias=False,
            device=device,
        )
        self.scaling = scaling

    def forward(self, x: torch.Tensor):
        """Applies the LoRA Adapter.

        Arguments
        ---------
        x: torch.Tensor
            Input tensor to the adapter module.

        Returns
        -------
        The pretrained outputs plus the scaled adapter outputs.
        """
        x_pretrained = self.pretrained_module(x)
        x_conv_lora = self.adapter_up_conv(self.adapter_up_scale(self.adapter_down_conv(x)))

        return x_pretrained + x_conv_lora * self.scaling
382+
]
383+
},
384+
{
385+
"cell_type": "code",
386+
"execution_count": 13,
387+
"id": "c2e702a9-c07d-4a76-94bc-847b8f890579",
309388
"metadata": {},
310389
"outputs": [],
311-
"source": []
390+
"source": [
391+
# Change the adapter out: point the output folder and the adapter class at
# the custom Conv2dLoRA defined in conv_lora.py.
train_yaml = train_yaml.replace(
    "output_folder: !ref results/crdnn_lora/<seed>",
    "output_folder: !ref results/crdnn_conv_lora/<seed>",
)
# NOTE: str.replace returns a new string (strings are immutable) — the result
# must be reassigned, otherwise the adapter substitution is silently dropped.
train_yaml = train_yaml.replace("""
model: !new:speechbrain.nnet.adapters.AdaptedModel
    model_to_adapt: !ref <model_pretrained>
    adapter_class: !name:speechbrain.nnet.adapters.LoRA
    adapter_kwargs:
        rank: 16
""", """
model: !new:speechbrain.nnet.adapters.AdaptedModel
    model_to_adapt: !ref <model_pretrained>
    adapter_class: !name:conv_lora.Conv2dLoRA
    adapter_kwargs:
        kernel_size: 3
        stride: 2
        channels: 16
""")

# Write the edited hyperparameter file for the conv-LoRA training run.
with open("train_conv_lora.yaml", "w") as f:
    f.write(train_yaml)
411+
]
412+
},
413+
{
414+
"cell_type": "code",
415+
"execution_count": 14,
416+
"id": "56aefd64-1325-4891-a9a4-1c4e85691b96",
417+
"metadata": {},
418+
"outputs": [
419+
{
420+
"name": "stdout",
421+
"output_type": "stream",
422+
"text": [
423+
"WARNING:speechbrain.utils.train_logger:torchvision is not available - cannot save figures\n",
424+
"/home/competerscience/Documents/Repositories/speechbrain/speechbrain/utils/autocast.py:68: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n",
425+
" wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)\n",
426+
"speechbrain.core - Beginning experiment!\n",
427+
"speechbrain.core - Experiment folder: results/crdnn_conv_lora/4324\n",
428+
"mini_librispeech_prepare - Preparation completed in previous run, skipping.\n",
429+
"../data/noise/data.zip exists. Skipping download\n",
430+
"../data/rir/data.zip exists. Skipping download\n",
431+
"speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model\n",
432+
"/home/competerscience/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
433+
" state_dict = torch.load(path, map_location=device)\n",
434+
"speechbrain.core - Exception:\n",
435+
"Traceback (most recent call last):\n",
436+
" File \"/home/competerscience/Documents/Repositories/speechbrain/docs/tutorials/nn/speechbrain/templates/speech_recognition/ASR/train.py\", line 461, in <module>\n",
437+
" hparams[\"pretrainer\"].load_collected()\n",
438+
" File \"/home/competerscience/Documents/Repositories/speechbrain/speechbrain/utils/parameter_transfer.py\", line 295, in load_collected\n",
439+
" self._call_load_hooks(paramfiles)\n",
440+
" File \"/home/competerscience/Documents/Repositories/speechbrain/speechbrain/utils/parameter_transfer.py\", line 312, in _call_load_hooks\n",
441+
" default_hook(obj, loadpath)\n",
442+
" File \"/home/competerscience/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py\", line 240, in torch_parameter_transfer\n",
443+
" state_dict = torch_patched_state_dict_load(path, device)\n",
444+
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
445+
" File \"/home/competerscience/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py\", line 199, in torch_patched_state_dict_load\n",
446+
" state_dict = torch.load(path, map_location=device)\n",
447+
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
448+
" File \"/home/competerscience/Documents/uvenv/lib/python3.12/site-packages/torch/serialization.py\", line 1065, in load\n",
449+
" with _open_file_like(f, 'rb') as opened_file:\n",
450+
" ^^^^^^^^^^^^^^^^^^^^^^^^\n",
451+
" File \"/home/competerscience/Documents/uvenv/lib/python3.12/site-packages/torch/serialization.py\", line 468, in _open_file_like\n",
452+
" return _open_file(name_or_buffer, mode)\n",
453+
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
454+
" File \"/home/competerscience/Documents/uvenv/lib/python3.12/site-packages/torch/serialization.py\", line 449, in __init__\n",
455+
" super().__init__(open(name, mode))\n",
456+
" ^^^^^^^^^^^^^^^^\n",
457+
"FileNotFoundError: [Errno 2] No such file or directory: '/home/competerscience/Documents/Repositories/speechbrain/docs/tutorials/nn/speechbrain/templates/speech_recognition/ASR/results/CRDNN_BPE_960h_LM/2602/save/lm.ckpt'\n"
458+
]
459+
}
460+
],
461+
"source": [
462+
"!python train.py train_conv_lora.yaml --number_of_epochs=1 --batch_size=2 --test_scorer \"!ref <valid_scorer>\" --enable_add_reverb=False --enable_add_noise=False #To speed up"
463+
]
464+
},
465+
{
466+
"cell_type": "markdown",
467+
"id": "21ef247c-3022-4b65-8cd0-86d01a618b79",
468+
"metadata": {},
469+
"source": [
470+
"## Conclusion\n",
471+
"\n",
472+
"That's it, thanks for following along! Go forth and make cool adapters."
473+
]
312474
}
313475
],
314476
"metadata": {

0 commit comments

Comments
 (0)