@@ -32,10 +32,10 @@ class S2SBaseSearcher(torch.nn.Module):
3232
3333 Returns
3434 -------
35- predictions:
35+ predictions
3636 Outputs as Python list of lists, with "ragged" dimensions; padding
3737 has been removed.
38- scores:
38+ scores
3939 The sum of log probabilities (and possibly
4040 additional heuristic scores) for each prediction.
4141
@@ -864,8 +864,9 @@ class S2SRNNBeamSearcher(S2SBeamSearcher):
864864 This class implements the beam search decoding
865865 for AttentionalRNNDecoder (speechbrain/nnet/RNN.py).
866866 See also S2SBaseSearcher(), S2SBeamSearcher().
867- Parameters
868- ----------
867+
868+ Arguments
869+ ---------
869870 embedding : torch.nn.Module
870871 An embedding layer
871872 decoder : torch.nn.Module
@@ -877,6 +878,7 @@ class S2SRNNBeamSearcher(S2SBeamSearcher):
877878 distribution, being softer when T>1 and sharper with T<1.
878879 **kwargs
879880 see S2SBeamSearcher, arguments are directly passed
881+
880882 Example
881883 -------
882884 >>> emb = torch.nn.Embedding(5, 3)
@@ -967,8 +969,9 @@ class S2SRNNBeamSearchLM(S2SRNNBeamSearcher):
967969 This class implements the beam search decoding
968970 for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with LM.
969971 See also S2SBaseSearcher(), S2SBeamSearcher(), S2SRNNBeamSearcher().
970- Parameters
971- ----------
972+
973+ Arguments
974+ ---------
972975 embedding : torch.nn.Module
973976 An embedding layer
974977 decoder : torch.nn.Module
@@ -982,6 +985,7 @@ class S2SRNNBeamSearchLM(S2SRNNBeamSearcher):
982985 distribution, being softer when T>1 and sharper with T<1.
983986 **kwargs
984987 Arguments to pass to S2SBeamSearcher
988+
985989 Example
986990 -------
987991 >>> from speechbrain.lobes.models.RNNLM import RNNLM
@@ -1060,8 +1064,9 @@ class S2SRNNBeamSearchTransformerLM(S2SRNNBeamSearcher):
10601064 This class implements the beam search decoding
10611065 for AttentionalRNNDecoder (speechbrain/nnet/RNN.py) with LM.
10621066 See also S2SBaseSearcher(), S2SBeamSearcher(), S2SRNNBeamSearcher().
1063- Parameters
1064- ----------
1067+
1068+ Arguments
1069+ ---------
10651070 embedding : torch.nn.Module
10661071 An embedding layer
10671072 decoder : torch.nn.Module
@@ -1075,6 +1080,7 @@ class S2SRNNBeamSearchTransformerLM(S2SRNNBeamSearcher):
10751080 distribution, being softer when T>1 and sharper with T<1.
10761081 **kwargs
10771082 Arguments to pass to S2SBeamSearcher
1083+
10781084 Example
10791085 -------
10801086 >>> from speechbrain.lobes.models.transformer.TransformerLM import TransformerLM
@@ -1142,8 +1148,8 @@ def inflate_tensor(tensor, times, dim):
11421148 """
11431149    This function inflates (repeats) the tensor `times` times along dim.
11441150
1145- Parameters
1146- ----------
1151+ Arguments
1152+ ---------
11471153 tensor : torch.Tensor
11481154 The tensor to be inflated.
11491155 times : int
@@ -1173,8 +1179,8 @@ def mask_by_condition(tensor, cond, fill_value):
11731179 """
11741180    This function will mask some elements of the tensor with fill_value, if condition=False.
11751181
1176- Parameters
1177- ----------
1182+ Arguments
1183+ ---------
11781184 tensor : torch.Tensor
11791185 The tensor to be masked.
11801186 cond : torch.BoolTensor
@@ -1283,8 +1289,8 @@ def lm_forward_step(self, inp_tokens, memory):
12831289def batch_filter_seq2seq_output(prediction, eos_id=-1):
12841290 """Calling batch_size times of filter_seq2seq_output.
12851291
1286- Parameters
1287- ----------
1292+ Arguments
1293+ ---------
12881294 prediction : list of torch.Tensor
12891295 a list containing the output ints predicted by the seq2seq system.
12901296 eos_id : int, string
@@ -1297,10 +1303,10 @@ def batch_filter_seq2seq_output(prediction, eos_id=-1):
12971303
12981304 Example
12991305 -------
1300- >>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])]
1301- >>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4)
1302- >>> predictions
1303- [[1, 2, 3], [2, 3]]
1306+ >>> predictions = [torch.IntTensor([1,2,3,4]), torch.IntTensor([2,3,4,5,6])]
1307+ >>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4)
1308+ >>> predictions
1309+ [[1, 2, 3], [2, 3]]
13041310 """
13051311 outputs = []
13061312    for p in prediction:
@@ -1312,8 +1318,8 @@ def batch_filter_seq2seq_output(prediction, eos_id=-1):
13121318def filter_seq2seq_output(string_pred, eos_id=-1):
13131319 """Filter the output until the first eos occurs (exclusive).
13141320
1315- Parameters
1316- ----------
1321+ Arguments
1322+ ---------
13171323 string_pred : list
13181324 a list containing the output strings/ints predicted by the seq2seq system.
13191325 eos_id : int, string
@@ -1326,10 +1332,10 @@ def filter_seq2seq_output(string_pred, eos_id=-1):
13261332
13271333 Example
13281334 -------
1329- >>> string_pred = ['a','b','c','d','eos','e']
1330- >>> string_out = filter_seq2seq_output(string_pred, eos_id='eos')
1331- >>> string_out
1332- ['a', 'b', 'c', 'd']
1335+ >>> string_pred = ['a','b','c','d','eos','e']
1336+ >>> string_out = filter_seq2seq_output(string_pred, eos_id='eos')
1337+ >>> string_out
1338+ ['a', 'b', 'c', 'd']
13331339 """
13341340    if isinstance(string_pred, list):
13351341        try:
0 commit comments