# ############################################################################
# Hyperparameters for a score-based diffusion speech-enhancement recipe
# (SGMSE+-style ScoreModel, SpeechBrain / HyperPyYAML format).
# `!ref` resolves other keys in this file; `!new:` instantiates a class;
# `!PLACEHOLDER` must be overridden on the command line or by the caller.
# ############################################################################

output_folder: results # Main directory to store experiment results
run_name: "RUN_NAME" # Will be updated with a unique name at runtime

save_dir: !ref <output_folder>/<run_name>/checkpoints # Directory to save checkpoints
enhanced_dir: !ref <output_folder>/<run_name>/enhanced_training # Directory to store waveforms at validation during training

data_dir: !PLACEHOLDER # Root dir for the dataset
train_annotation: !ref <data_dir>/train.json # JSON file listing training samples
valid_annotation: !ref <data_dir>/valid.json # JSON file listing validation samples
test_annotation: !ref <data_dir>/test.json # JSON file listing test samples

skip_prep: False # If True, skip data preparation steps
segment_frames: 256 # Number of STFT frames fed into the model. Must align with what the model expects due to the U-Net architecture
random_crop: True # Whether to crop segments randomly from longer waveforms in training
random_crop_valid: False # Whether to crop segments randomly from longer waveforms in validation
random_crop_test: False # Whether to crop segments randomly from longer waveforms in testing

normalize: noisy # Waveforms are normalized with respect to ... (noisy / clean / not)
sample_rate: 16000 # Sampling rate (in Hz) for audio data
batch_size: 8 # Batch size for the training set
number_of_epochs: 160 # Total epochs to train
num_to_keep: 2 # Number of checkpoints to keep
lr: 0.0001 # Learning rate
sorting: ascending # Sorting strategy for data loading (e.g., ascending, descending)

# STFT settings
n_fft: 510 # FFT size for STFT
hop_length: 128 # Hop length (stride) for STFT
window_type: hann # Type of window function for STFT

# Spectrogram transformation settings
transform_type: exponent # Type of spectral transform (log, exponent, none)
spec_factor: 0.15 # Factor to scale the transformed spectrogram
spec_abs_exponent: 0.5 # Exponent to apply to spectrogram magnitude if needed

train_dataloader_opts:
  batch_size: !ref <batch_size>
  shuffle: True # Shuffle training data each epoch

valid_dataloader_opts:
  batch_size: 1 # Validation batch size

test_dataloader_opts:
  batch_size: 1 # Test batch size

# Reverse-diffusion sampling settings (referenced by the score model below)
sampling:
  sampler_type: pc # Sampler type (pc = predictor-corrector)
  predictor: reverse_diffusion # Predictor used in the sampler
  corrector: ald # Corrector used in the sampler (annealed Langevin dynamics -- TODO confirm against sampler implementation)
  N: 30 # Number of sampler steps
  corrector_steps: 1 # Corrector updates per predictor step
  snr: 0.5 # SNR parameter for the corrector

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
  limit: !ref <number_of_epochs> # Sets the upper bound on training epochs

modules:
  score_model: !new:speechbrain.integrations.models.sgmse_plus.ScoreModel
    backbone: ncsnpp_v2 # Name of the backbone neural network architecture
    sde: ouve # Which SDE to use (Ornstein-Uhlenbeck VE SDE)
    theta: 1.5 # Stiffness parameter for the OU SDE
    sigma_min: 0.05 # Minimum sigma value for OU SDE
    sigma_max: 0.5 # Maximum sigma value for OU SDE
    lr: !ref <lr> # Learning rate for the model
    ema_decay: 0.999 # Decay factor for EMA of model parameters
    t_eps: 0.03 # Min time-step to avoid zero in continuous diffusion
    num_eval_files: 5 # Number of files to process for evaluation
    loss_type: score_matching # Which loss approach to use (score matching, etc.)
    loss_weighting: sigma^2 # Weighting in the loss function
    network_scaling: 1/t # Scaling strategy (if any) for network outputs
    c_in: "1" # Input scaling scheme
    c_out: "1" # Output scaling scheme
    c_skip: "0" # Skip connection scaling scheme
    sigma_data: 0.1 # Data STD for EDM-based parameterizations
    l1_weight: 0.001 # Weight factor for L1 (time-domain) loss
    pesq_weight: 0.0 # Weight factor for PESQ-based loss (0 = disabled)
    N: !ref <sampling[N]> # Sampler steps
    corrector_steps: !ref <sampling[corrector_steps]> # Corrector updates per step
    sampler_type: !ref <sampling[sampler_type]> # SDE sampler type
    snr: !ref <sampling[snr]> # SNR for sampler
    sr: !ref <sample_rate> # Sample rate for model references

opt_class: !name:torch.optim.Adam
  lr: !ref <lr> # LR used in the Adam optimizer

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
  checkpoints_dir: !ref <save_dir> # Directory to store checkpoint files
  recoverables:
    score_model: !ref <modules[score_model]> # Model parameters to be saved
    counter: !ref <epoch_counter> # Epoch counter to be saved