CTC beamsearch decoding via ctcdecode #773
@30stomercury Here.
@Antoine-Caubriere Could you please add an example to your class, just like for the other decoders? This will help users understand how to use it.
Hi @Antoine-Caubriere, can you share some examples or details of your experiments? Thanks.
Hi @TParcollet, @30stomercury. This implementation is very simple to use. First, you have to add the CTCDecodeBeamSearch module to your YAML configuration file. An example from my experiments:

```yaml
ctc_beam_search_module: !new:speechbrain.decoders.CTCDecodeBeamSearch
    labels: ['<blank>', 'u', ' ', 'r', 'e', 'v', 'o', 'i', 'm', 'c', 's', 'x', 'n', 't', 'h', "'", 'p', 'f', 'd', 'b', 'é', 'z', 'ç', 'è', 'à', 'l', 'g', 'j', 'ô', 'k', 'q', 'â', 'ù', 'î', 'y', 'û', 'ë', 'ê', 'w', '[', ']', 'ï', 'a', '<unk>']
    model_path: /users/lm/media.noconcept.4g.unk.mmap
    log_probs_input: True
```

Then you can simply use it in compute_objectives. An example:

```python
def compute_forward(self, batch, stage):
    """Forward computations from the waveform batches to the output probabilities."""
    batch = batch.to(self.device)
    wavs, wav_lens = batch.sig

    # Forward pass
    feats = self.modules.wav2vec2(wavs)
    x = self.modules.enc(feats)
    logits = self.modules.ctc_lin(x)
    p_ctc = self.hparams.log_softmax(logits)

    return p_ctc, wav_lens

def compute_objectives(self, predictions, batch, stage):
    """Computes the CTC loss given predictions and targets."""
    p_ctc, wav_lens = predictions
    chars, char_lens = batch.char_encoded

    loss = self.hparams.ctc_cost(p_ctc, chars, wav_lens, char_lens)
    self.ctc_metrics.append(batch.id, p_ctc, chars, wav_lens, char_lens)

    if stage != sb.Stage.TRAIN:
        if hparams["ctc_beam_search"]:  # <--- here, a boolean in the configuration file
            sequence = self.hparams.ctc_beam_search_module(p_ctc)
        else:
            sequence = sb.decoders.ctc_greedy_decode(
                p_ctc, wav_lens, self.hparams.blank_index
            )
        self.cer_metrics.append(
            ids=batch.id,
            predict=sequence,
            target=chars,
            target_len=char_lens,
            ind2lab=self.label_encoder.decode_ndim,
        )

    return loss
```

By default, the blank token is at index 0, but you can change this in the YAML if you need to (as with all the other parameters: alpha, beta, beam_width, ...). You can also use the nBest parameter to extract more than the 1-best hypothesis, for example: `sequences = self.hparams.ctc_beam_search_module(p_ctc, nBest=5)`
Hi @Antoine-Caubriere, thanks.
Thank you @Antoine-Caubriere! That looks like a nice addition. Our goal is to integrate our code with k2 and WFSTs, but this looks like a reasonable solution for now.
Hi @Antoine-Caubriere, thanks for this PR. Have you tested this CTC decoder with wordpiece tokens? On my side, the n-gram model can only be a wordpiece-based n-gram model.
@30stomercury Can't we simply extract the labels from the currently loaded tokenizer? @30stomercury, since you already have a BPE-based n-gram LM, could you try to see if it works? I don't think @Antoine-Caubriere has one.
I would like to merge this PR ASAP ... |
Hi @TParcollet, we can merge it first. I think the labels can be obtained from the currently loaded tokenizer.
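As a sketch of the suggestion above, the `labels` list could be derived from the loaded tokenizer's vocabulary instead of being hard-coded in the YAML. The helper and the `id_to_token`/`vocab_size` interface below are hypothetical, not an existing SpeechBrain API (SentencePiece, for instance, exposes the analogous `id_to_piece()` and `get_piece_size()`):

```python
def labels_from_tokenizer(id_to_token, vocab_size):
    """Return one label per acoustic-model output unit, in output-index order.

    ctcdecode matches labels to output indices positionally, so the order must
    agree with the model's output layer (including the blank symbol).
    `id_to_token` is any callable mapping an integer id to its token string.
    """
    return [id_to_token(i) for i in range(vocab_size)]
```

With a toy vocabulary `{0: "<blank>", 1: "a", 2: "b"}`, `labels_from_tokenizer(vocab.get, 3)` yields `["<blank>", "a", "b"]`.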
It looks like tests are failing. We have to think about what we would like to do here. This C++ solution seems very fast, while the solution that @30stomercury is setting up is slower but more flexible (it supports many things, including CTC + Transformer LM). Moreover, in the future we might consider integration with the WFSTs of k2. The risk is having too many search solutions, and users could easily get lost. What do you think?
This is a quick patch if we do not advertise it. @30stomercury, do you have an estimate of the time needed to finish your work? One month? I think we can merge this and remove it later. We can drop support for features, just like we will do with the Fairseq w2v2 once I have integrated the pretraining phase.
@mravanelli let's discuss this on Slack later today. Maybe it makes sense to wait for @30stomercury's solution, and we just redirect people interested to this PR?
Hi @TParcollet, I discussed this with Mirco in the meeting this week. The main blocks are done and the performance looks reasonable. I will adapt it to all recipes once the other developers are okay with the implementation of the scoring part in #751.
Hey, another, more lightweight and maybe more flexible option could be: https://github.com/kensho-technologies/pyctcdecode
I think we should close this, as we are moving to k2 FST, right? @TParcollet and @Antoine-Caubriere
Yes |
@mravanelli I was wondering last time whether we shouldn't bring this class back. I mean, for now we have absolutely no n-gram rescoring, only LM fusion (which doesn't work with a word-level LM). Maybe we should add it back, as it is always great to be compatible with ctcdecode? I dunno.
Quick note: beam search decoding here uses this implementation: https://github.com/parlance/ctcdecode
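For readers who want to see what the linked C++ library computes, below is a minimal pure-Python sketch of CTC prefix beam search without a language model. It is an illustration of the algorithm only, not the ctcdecode implementation, and all names are mine:

```python
import math
from collections import defaultdict

NEG_INF = float("-inf")

def logsumexp(*xs):
    """Numerically stable log(sum(exp(x))) over the given log values."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_prefix_beam_search(log_probs, blank=0, beam_width=4):
    """CTC prefix beam search for one utterance.

    log_probs: list of per-frame lists of log probabilities, shape T x V.
    Each beam entry maps a prefix (tuple of label ids) to a pair
    (log prob of paths ending in blank, log prob of paths ending in non-blank).
    Returns the most probable collapsed label sequence as a list of ids.
    """
    beams = {(): (0.0, NEG_INF)}  # empty prefix "ends in blank" with prob 1
    for frame in log_probs:
        next_beams = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (p_b, p_nb) in beams.items():
            for s, p in enumerate(frame):
                if s == blank:
                    # a blank keeps the prefix unchanged
                    nb_b, nb_nb = next_beams[prefix]
                    next_beams[prefix] = (logsumexp(nb_b, p_b + p, p_nb + p), nb_nb)
                else:
                    new_prefix = prefix + (s,)
                    nb_b, nb_nb = next_beams[new_prefix]
                    if prefix and s == prefix[-1]:
                        # repeated label: extending requires a blank in between
                        next_beams[new_prefix] = (nb_b, logsumexp(nb_nb, p_b + p))
                        # staying on the same label collapses into the old prefix
                        sb_b, sb_nb = next_beams[prefix]
                        next_beams[prefix] = (sb_b, logsumexp(sb_nb, p_nb + p))
                    else:
                        next_beams[new_prefix] = (
                            nb_b, logsumexp(nb_nb, p_b + p, p_nb + p)
                        )
        # prune: keep only the beam_width most probable prefixes
        beams = dict(sorted(next_beams.items(),
                            key=lambda kv: logsumexp(*kv[1]),
                            reverse=True)[:beam_width])
    best_prefix, _ = max(beams.items(), key=lambda kv: logsumexp(*kv[1]))
    return list(best_prefix)
```

The C++ version adds, among other things, an optional KenLM language model score (weighted by alpha/beta) when expanding prefixes, which is where the `model_path` parameter from the YAML example above comes in.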