-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathdynamic_chunk_training.py
More file actions
188 lines (156 loc) · 6.75 KB
/
dynamic_chunk_training.py
File metadata and controls
188 lines (156 loc) · 6.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""Configuration and utility classes for classes for Dynamic Chunk Training, as
often used for the training of streaming-capable models in speech recognition.
The definition of Dynamic Chunk Training is based on that of the following
paper, though a lot of the literature refers to the same definition:
https://arxiv.org/abs/2012.05481
Authors
* Sylvain de Langen 2023
"""
from dataclasses import dataclass
from typing import Optional
import torch
import speechbrain as sb
# NOTE: this configuration object is intended to be relatively specific to
# Dynamic Chunk Training; if you want to implement a different similar type of
# chunking different from that, you should consider using a different object.
@dataclass
class DynChunkTrainConfig:
    """Dynamic Chunk Training configuration object for use with transformers,
    often in ASR for streaming.

    This object may be used both to configure masking at training time and for
    run-time configuration of DynChunkTrain-ready models.
    """

    chunk_size: int
    """Size in frames of a single chunk, always `>0`.
    If chunkwise streaming should be disabled at some point, pass an optional
    streaming config parameter."""

    left_context_size: Optional[int] = None
    """Number of *chunks* (not frames) visible to the left, always `>=0`.
    If zero, then chunks can never attend to any past chunk.
    If `None`, the left context is infinite (but use
    `.is_infinite_left_context` for such a check)."""

    def is_infinite_left_context(self) -> bool:
        """Whether chunks may attend arbitrarily far into the past.

        Returns
        -------
        bool
            ``True`` when no left-context limit is configured.
        """
        # An unset (None) left context is the "infinite" sentinel.
        return self.left_context_size is None

    def left_context_size_frames(self) -> Optional[int]:
        """Convert the left context from chunk units to frame units.

        Returns
        -------
        Optional[int]
            Number of visible left-context *frames*, or ``None`` when the
            left context is infinite. See also the ``left_context_size``
            field.
        """
        if self.is_infinite_left_context():
            return None

        # Each of the `left_context_size` chunks spans `chunk_size` frames.
        return self.left_context_size * self.chunk_size
@dataclass
class DynChunkTrainConfigRandomSampler:
    """Helper class to generate a DynChunkTrainConfig at runtime depending on
    the current stage.

    Example
    -------
    >>> from speechbrain.core import Stage
    >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
    >>> from speechbrain.utils.dynamic_chunk_training import (
    ...     DynChunkTrainConfigRandomSampler,
    ... )
    >>> # for the purpose of this example, we test a scenario with a 100%
    >>> # chance of the (24, None) scenario to occur
    >>> sampler = DynChunkTrainConfigRandomSampler(
    ...     chunkwise_prob=1.0,
    ...     chunk_size_min=24,
    ...     chunk_size_max=24,
    ...     limited_left_context_prob=0.0,
    ...     left_context_chunks_min=16,
    ...     left_context_chunks_max=16,
    ...     test_config=DynChunkTrainConfig(32, 16),
    ...     valid_config=None,
    ... )
    >>> one_train_config = sampler(Stage.TRAIN)
    >>> one_train_config
    DynChunkTrainConfig(chunk_size=24, left_context_size=None)
    >>> one_train_config.is_infinite_left_context()
    True
    >>> sampler(Stage.TEST)
    DynChunkTrainConfig(chunk_size=32, left_context_size=16)
    """

    chunkwise_prob: float
    """When sampling (during `Stage.TRAIN`), the probability that a finite
    chunk size will be used.
    In the other case, any chunk can attend to the full past and future
    context."""

    chunk_size_min: int
    """When sampling a random chunk size, the minimum chunk size that can be
    picked."""

    chunk_size_max: int
    """When sampling a random chunk size, the maximum chunk size that can be
    picked."""

    limited_left_context_prob: float
    """When sampling a random chunk size, the probability that the left
    context will be limited.
    In the other case, any chunk can attend to the full past context."""

    left_context_chunks_min: int
    """When sampling a random left context size, the minimum number of left
    context chunks that can be picked."""

    left_context_chunks_max: int
    """When sampling a random left context size, the maximum number of left
    context chunks that can be picked."""

    test_config: Optional[DynChunkTrainConfig] = None
    """The configuration that should be used for `Stage.TEST`.
    When `None`, evaluation is done with full context (i.e. non-streaming)."""

    valid_config: Optional[DynChunkTrainConfig] = None
    """The configuration that should be used for `Stage.VALID`.
    When `None`, evaluation is done with full context (i.e. non-streaming)."""

    def _sample_bool(self, prob):
        """Samples a random boolean with a probability, in a way that depends
        on PyTorch's RNG seed.

        Arguments
        ---------
        prob : float
            Probability (0..1) to return True (False otherwise).

        Returns
        -------
        The sampled boolean
        """
        # Draw from PyTorch's RNG (not `random`) so seeding torch controls us.
        draw = torch.rand((1,)).item()
        return draw < prob

    def __call__(self, stage):
        """In training stage, samples a random DynChunkTrain configuration.
        During validation or testing, returns the relevant configuration.

        Arguments
        ---------
        stage : speechbrain.core.Stage
            Current stage of training or evaluation.
            In training mode, a random DynChunkTrainConfig will be sampled
            according to the specified probabilities and ranges.
            During evaluation, the relevant DynChunkTrainConfig attribute will
            be picked.

        Returns
        -------
        The appropriate configuration
        """
        if stage == sb.core.Stage.TRAIN:
            # When training for streaming, for each batch, we have a
            # `dynamic_chunk_prob` probability of sampling a chunk size
            # between `dynamic_chunk_min` and `_max`, otherwise output
            # frames can see anywhere in the future.
            if not self._sample_bool(self.chunkwise_prob):
                # Full-context batch: no chunking at all this time.
                return None

            # `randint` upper bound is exclusive, hence the `+ 1`.
            chunk_size = torch.randint(
                self.chunk_size_min,
                self.chunk_size_max + 1,
                (1,),
            ).item()

            if self._sample_bool(self.limited_left_context_prob):
                left_context_chunks = torch.randint(
                    self.left_context_chunks_min,
                    self.left_context_chunks_max + 1,
                    (1,),
                ).item()
            else:
                # Unlimited left context for this batch.
                left_context_chunks = None

            return DynChunkTrainConfig(chunk_size, left_context_chunks)

        if stage == sb.core.Stage.TEST:
            return self.test_config

        if stage == sb.core.Stage.VALID:
            return self.valid_config

        raise AttributeError(f"Unsupported stage found {stage}")