|
| 1 | +################################################### |
| 2 | +# Classification of imagined movements of BNCI2014004 MOABB dataset using DeepConvNet. |
| 3 | +# DeepConvNet from . |
| 4 | +# BNCI2014004 from . |
| 5 | +# |
| 6 | +# Author |
| 7 | +# ------ |
| 8 | +# Davide Borra, 2022 |
| 9 | +# Mirco Ravanelli, 2022 |
| 10 | +# Francesco Paissan, 2022 |
| 11 | +################################################### |
| 12 | +seed: 1234 |
| 13 | +__set_torchseed: !apply:torch.manual_seed [!ref <seed>] |
| 14 | + |
| 15 | +# DIRECTORIES |
| 16 | +data_folder: !PLACEHOLDER #'/path/to/dataset'. The dataset will be automatically downloaded in this folder |
| 17 | +cached_data_folder: !PLACEHOLDER #'path/to/pickled/dataset' |
| 18 | +output_folder: !ref results/MotorImagery/BNCI2014004/DeepConvNet/<seed> |
| 19 | + |
| 20 | + |
| 21 | +# DATASET HPARS |
| 22 | +# Defining the MOABB dataset. |
| 23 | +dataset: !new:moabb.datasets.BNCI2014004 |
| 24 | +save_prepared_dataset: True # set to True if you want to save the prepared dataset as a pkl file to load and use afterwards |
| 25 | + |
| 26 | +data_iterator_name: !PLACEHOLDER |
| 27 | +target_subject_idx: !PLACEHOLDER |
| 28 | +target_session_idx: !PLACEHOLDER |
| 29 | +events_to_load: null # all events will be loaded |
| 30 | +original_sample_rate: 250 # Original sampling rate provided by dataset authors |
| 31 | +sample_rate: 125 # Target sampling rate (Hz) |
| 32 | + |
| 33 | +# band-pass filtering cut-off frequencies |
| 34 | +fmin: 0.14 # @orion_step1: --fmin~"uniform(1, 5, precision=2)" |
| 35 | +fmax: 45.7 # @orion_step1: --fmax~"uniform(30.0, 50.0, precision=3)" |
| 36 | + |
| 37 | +n_classes: 2 |
| 38 | + |
| 39 | +# tmin, tmax respect to stimulus onset that define the interval attribute of the dataset class |
| 40 | +# trial begins (0 s), cue (2 s, 1.25 s long); each trial is 6 s long |
| 41 | +# dataset interval starts from 2 |
| 42 | +# -->tmin tmax are referred to this start value (e.g., tmin=0.5 corresponds to 2.5 s) |
| 43 | +tmin: 0. |
| 44 | +tmax: 3.9 # @orion_step1: --tmax~"uniform(1.0, 4.0, precision=2)" |
| 45 | +# number of steps used when selecting adjacent channels from a seed channel (default at Cz) |
| 46 | +n_steps_channel_selection: 1 |
| 47 | + |
| 48 | +T: !apply:math.ceil |
| 49 | + - !ref <sample_rate> * (<tmax> - <tmin>) |
| 50 | +C: 3 |
| 51 | +n_train_examples: 100 # it will be replaced in the train script |
| 52 | +# We here specify how to perfom test: |
| 53 | +# - If test_with: 'last' we perform test with the latest model. |
| 54 | +# - if test_with: 'best, we perform test with the best model (according to the metric specified in test_key) |
| 55 | +# The variable avg_models can be used to average the parameters of the last (or best) N saved models before testing. |
| 56 | +# This can have a regularization effect. If avg_models: 1, the last (or best) model is used directly. |
| 57 | +test_with: 'last' # 'last' or 'best' |
| 58 | +test_key: "acc" # Possible opts: "loss", "f1", "auc", "acc" |
| 59 | + |
| 60 | +# checkpoints to average |
| 61 | +avg_models: 15 # @orion_step1: --avg_models~"uniform(1, 15,discrete=True)" |
| 62 | + |
| 63 | +f1: !name:sklearn.metrics.f1_score |
| 64 | + average: 'macro' |
| 65 | +acc: !name:sklearn.metrics.balanced_accuracy_score |
| 66 | +cm: !name:sklearn.metrics.confusion_matrix |
| 67 | +metrics: |
| 68 | + f1: !ref <f1> |
| 69 | + acc: !ref <acc> |
| 70 | + cm: !ref <cm> |
| 71 | + |
| 72 | +# TRAINING HPARS |
| 73 | +number_of_epochs: 940 # @orion_step1: --number_of_epochs~"uniform(250, 1000, discrete=True)" |
| 74 | +lr: 0.0001 # @orion_step1: --lr~"choices([0.01, 0.005, 0.001, 0.0005, 0.0001])" |
| 75 | + |
| 76 | +# Learning rate scheduling (cyclic learning rate is used here) |
| 77 | +max_lr: !ref <lr> # Upper bound of the cycle (max value of the lr) |
| 78 | +base_lr: 0.00000001 # Lower bound in the cycle (min value of the lr) |
| 79 | + |
| 80 | +step_size_multiplier: 5 #from 2 to 8 |
| 81 | +step_size: !apply:round |
| 82 | + - !ref <step_size_multiplier> * <n_train_examples> / <batch_size> |
| 83 | + |
| 84 | +label_smoothing: 0.0 |
| 85 | + |
| 86 | +loss: !name:speechbrain.nnet.losses.nll_loss |
| 87 | + label_smoothing: !ref <label_smoothing> |
| 88 | + |
| 89 | +optimizer: !name:torch.optim.Adam |
| 90 | + lr: !ref <lr> |
| 91 | +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter # epoch counter |
| 92 | + limit: !ref <number_of_epochs> |
| 93 | +batch_size_exponent: 4 # @orion_step1: --batch_size_exponent~"uniform(4, 6,discrete=True)" |
| 94 | +batch_size: !ref 2 ** <batch_size_exponent> |
| 95 | +valid_ratio: 0.2 |
| 96 | + |
| 97 | +# DATA AUGMENTATION |
| 98 | +# cutcat (disabled when min_num_segments=max_num_segments=1) |
| 99 | +max_num_segments: 3 # @orion_step2: --max_num_segments~"uniform(2, 6, discrete=True)" |
| 100 | +cutcat: !new:speechbrain.processing.speech_augmentation.CutCat |
| 101 | + min_num_segments: 2 |
| 102 | + max_num_segments: !ref <max_num_segments> |
| 103 | + |
| 104 | +# random amplitude gain between 0.5-1.5 uV (disabled when amp_delta=0.) |
| 105 | +amp_delta: 0.3492 # @orion_step2: --amp_delta~"uniform(0.0, 0.5)" |
| 106 | +rand_amp: !new:speechbrain.processing.speech_augmentation.RandAmp |
| 107 | + amp_low: !ref 1 - <amp_delta> |
| 108 | + amp_high: !ref 1 + <amp_delta> |
| 109 | + |
| 110 | +# random shifts between -300 ms to 300 ms (disabled when shift_delta=0.) |
| 111 | +shift_delta_: 3 #@orion_step2: --shift_delta_~"uniform(0, 25, discrete=True)" |
| 112 | +shift_delta: !ref 1e-2 * <shift_delta_> # 0.250 # 0.-0.25 with steps of 0.01 |
| 113 | +min_shift: !apply:math.floor |
| 114 | + - !ref 0 - <sample_rate> * <shift_delta> |
| 115 | +max_shift: !apply:math.floor |
| 116 | + - !ref 0 + <sample_rate> * <shift_delta> |
| 117 | +time_shift: !new:speechbrain.processing.speech_augmentation.RandomShift |
| 118 | + min_shift: !ref <min_shift> |
| 119 | + max_shift: !ref <max_shift> |
| 120 | + dim: 1 |
| 121 | + |
| 122 | +# injection of gaussian white noise |
| 123 | +snr_white_low: 9.1 # @orion_step2: --snr_white_low~"uniform(0.0, 15, precision=2)" |
| 124 | +snr_white_delta: 11.3 # @orion_step2: --snr_white_delta~"uniform(5.0, 20.0, precision=3)" |
| 125 | +snr_white_high: !ref <snr_white_low> + <snr_white_delta> |
| 126 | +add_noise_white: !new:speechbrain.processing.speech_augmentation.AddNoise |
| 127 | + snr_low: !ref <snr_white_low> |
| 128 | + snr_high: !ref <snr_white_high> |
| 129 | + |
| 130 | +repeat_augment: 1 # @orion_step1: --repeat_augment 0 |
| 131 | +augment: !new:speechbrain.processing.augmentation.Augmenter |
| 132 | + parallel_augment: True |
| 133 | + concat_original: True |
| 134 | + parallel_augment_fixed_bs: True |
| 135 | + repeat_augment: !ref <repeat_augment> |
| 136 | + shuffle_augmentations: True |
| 137 | + min_augmentations: 4 |
| 138 | + max_augmentations: 4 |
| 139 | + cutcat: !ref <cutcat> |
| 140 | + rand_amp: !ref <rand_amp> |
| 141 | + time_shift: !ref <time_shift> |
| 142 | + augment_noise: !ref <add_noise_white> |
| 143 | + |
| 144 | +# DATA NORMALIZATION |
| 145 | +dims_to_normalize: 1 # 1 (time) or 2 (EEG channels) |
| 146 | +normalize: !name:speechbrain.processing.signal_processing.mean_std_norm |
| 147 | + dims: !ref <dims_to_normalize> |
| 148 | + |
| 149 | +# MODEL: DEEPCONVNET |
| 150 | +input_shape: [null, !ref <T>, !ref <C>, null] |
| 151 | +cnn_temporal_kernels: 25 # @orion_step1: --cnn_temporal_kernels~"uniform(4, 64,discrete=True)" |
| 152 | +cnn_temporal_kernelsize: 10 # @orion_step1: --cnn_temporal_kernelsize~"uniform(5, 62,discrete=True)" |
| 153 | +cnn_spatial_kernels: 25 # @orion_step1: --cnn_spatial_kernels~"uniform(4, 64,discrete=True)" |
| 154 | +cnn_spatial_pool: 1 # disabling pooling as we use 125/128 Hz data (and not 250/256 Hz data) |
| 155 | + |
| 156 | +cnn_temporal_block_multiplier: 2 # @orion_step1: --cnn_temporal_block_multiplier~"uniform(1, 3,discrete=True)" |
| 157 | +cnn_temporal_block_kernel0: 50 # @orion_step1: --cnn_temporal_block_kernel0~"uniform(4, 64,discrete=True)" |
| 158 | +cnn_temporal_block_kernel1: !ref <cnn_temporal_block_kernel0> * <cnn_temporal_block_multiplier> |
| 159 | +cnn_temporal_block_kernel2: !ref <cnn_temporal_block_kernel1> * <cnn_temporal_block_multiplier> |
| 160 | +cnn_temporal_block_kernelsize0: 10 # @orion_step1: --cnn_temporal_block_kernelsize0~"uniform(5, 20,discrete=True)" |
| 161 | +cnn_temporal_block_kernelsize1: !ref <cnn_temporal_block_kernelsize0> |
| 162 | +cnn_temporal_block_kernelsize2: !ref <cnn_temporal_block_kernelsize0> |
| 163 | +cnn_temporal_block_pool: 1 # disabling pooling as we use 125/128 Hz data (and not 250/256 Hz data) |
| 164 | +dropout: 0.1748 # @orion_step1: --dropout~"uniform(0.0, 0.5)" |
| 165 | +activation_type: 'elu' |
| 166 | + |
| 167 | +model: !new:models.DeepConvNet.DeepConvNet |
| 168 | + input_shape: !ref <input_shape> |
| 169 | + cnn_temporal_kernels: !ref <cnn_temporal_kernels> |
| 170 | + cnn_temporal_kernelsize: [!ref <cnn_temporal_kernelsize>, 1] |
| 171 | + cnn_spatial_kernels: !ref <cnn_spatial_kernels> |
| 172 | + cnn_spatial_pool: [!ref <cnn_spatial_pool>, 1] |
| 173 | + cnn_temporal_block_kernel0: !ref <cnn_temporal_block_kernel0> |
| 174 | + cnn_temporal_block_kernel1: !ref <cnn_temporal_block_kernel1> |
| 175 | + cnn_temporal_block_kernel2: !ref <cnn_temporal_block_kernel2> |
| 176 | + cnn_temporal_block_kernelsize0: [!ref <cnn_temporal_block_kernelsize0>, 1] |
| 177 | + cnn_temporal_block_kernelsize1: [!ref <cnn_temporal_block_kernelsize1>, 1] |
| 178 | + cnn_temporal_block_kernelsize2: [!ref <cnn_temporal_block_kernelsize2>, 1] |
| 179 | + cnn_temporal_block_pool: [!ref <cnn_temporal_block_pool>, 1] |
| 180 | + activation_type: !ref <activation_type> |
| 181 | + dropout: !ref <dropout> |
| 182 | + dense_n_neurons: !ref <n_classes> |
| 183 | + |
| 184 | +lr_annealing: !new:speechbrain.nnet.schedulers.CyclicLRScheduler |
| 185 | + base_lr: !ref <base_lr> |
| 186 | + max_lr: !ref <max_lr> |
| 187 | + step_size: !ref <step_size> |
0 commit comments