"""Author: Awni HannunThis is an example CTC decoder written in Python. The code isintended to be a simple example and is not designed to beespecially efficient.The algorithm is a prefix beam search for a model trainedwith the CTC loss function.For more details checkout either of these references: https://distill.pub/2017/ctc/#inference https://arxiv.org/abs/1408.2873"""importnumpyasnpimportmathimportcollectionsNEG_INF=-float("inf")defdecode(probs,beam_size=10,blank=0):""" Performs inference for the given output probabilities. Arguments: probs: The output probabilities (e.g. log post-softmax) for each time step. Should be an array of shape (time x output dim). beam_size (int): Size of the beam to use during inference. blank (int): Index of the CTC blank label. Returns the output label sequence and the corresponding negative log-likelihood estimated by the decoder. """T,S=probs.shape# Elements in the beam are (prefix, (p_blank, p_no_blank))# Initialize the beam with the empty sequence, a probability of# 1 for ending in blank and zero for ending in non-blank# (in log space).beam=[(tuple(),(0.0,NEG_INF))]fortinrange(T):# Loop over time# A default dictionary to store the next step candidates.next_beam=make_new_beam()forsinrange(S):# Loop over vocabp=probs[t,s]# The variables p_b and p_nb are respectively the# probabilities for the prefix given that it ends in a# blank and does not end in a blank at this time step.forprefix,(p_b,p_nb)inbeam:# Loop over beam# 情况1# If we propose a blank the prefix doesn't change.# Only the probability of ending in blank gets updated.ifs==blank:# n_p_b:当前时刻blank概率,n_p_nb:当前时刻非blank概率n_p_b,n_p_nb=next_beam[prefix]# logsumexp(n_p_b, p_b + p, p_nb + p): # n_p_b x ((p_b + p) + (p_nb + p))n_p_b=logsumexp(n_p_b,p_b+p,p_nb+p)next_beam[prefix]=(n_p_b,n_p_nb)continue# Extend the prefix by the new character s and add it to# the beam. Only the probability of not ending in blank# gets updated.end_t=prefix[-1]ifprefixelseNonen_prefix=prefix+(s,)n_p_b,n_p_nb=next_beam[n_prefix]ifs!=end_t:# 情况4n_p_nb=logsumexp(n_p_nb,p_b+p,p_nb+p)else:# 情况2和3# We don't include the previous probability of not ending# in blank (p_nb) if s is repeated at the end. The CTC# algorithm merges characters not separated by a blank.n_p_nb=logsumexp(n_p_nb,p_b+p)# *NB* this would be a good place to include an LM score.next_beam[n_prefix]=(n_p_b,n_p_nb)# If s is repeated at the end we also update the unchanged# prefix. This is the merging case.ifs==end_t:# 情况2,规整字符串不变,但规整概率需要更新n_p_b,n_p_nb=next_beam[prefix]n_p_nb=logsumexp(n_p_nb,p_nb+p)next_beam[prefix]=(n_p_b,n_p_nb)# Sort and trim the beam before moving on to the# next time-step.beam=sorted(next_beam.items(),key=lambdax:logsumexp(*x[1]),reverse=True)beam=beam[:beam_size]best=beam[0]returnbest[0],-logsumexp(*best[1])
// wenet/runtime/core/decoder/ctc_prefix_beam_search.cc
// 1. First beam prune, only select topk candidates
std::tuple<Tensor, Tensor> topk = logp_t.topk(opts_.first_beam_size);
Tensor topk_score = std::get<0>(topk);
Tensor topk_index = std::get<1>(topk);
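The first prune keeps only the first_beam_size highest-scoring labels of the current frame, so the inner vocabulary loop of the prefix beam search runs over a handful of candidates instead of the full output dimension. A rough NumPy equivalent of this step (the helper name and default size are illustrative, not WeNet's API):

import numpy as np

def first_beam_prune(logp_t, first_beam_size=10):
    # Mirror logp_t.topk(opts_.first_beam_size): indices and scores
    # of the top-k log-probs in one frame, highest first.
    topk_index = np.argsort(logp_t)[::-1][:first_beam_size]
    topk_score = logp_t[topk_index]
    return topk_score, topk_index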
// 3. Second beam prune, only keep top n best paths
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(next_hyps.begin(),
                                                          next_hyps.end());
int second_beam_size =
    std::min(static_cast<int>(arr.size()), opts_.second_beam_size);
std::nth_element(arr.begin(), arr.begin() + second_beam_size, arr.end(),
                 PrefixScoreCompare);
arr.resize(second_beam_size);
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
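std::nth_element does a linear-time partial selection, so only the surviving second_beam_size hypotheses are fully sorted afterwards, which is cheaper than sorting every prefix in next_hyps. A Python sketch of the same idea, assuming next_hyps maps prefix tuples to a single combined log-score (a simplified stand-in for WeNet's PrefixScore):

import heapq

def second_beam_prune(next_hyps, second_beam_size=10):
    # heapq.nlargest combines the partial selection and the final
    # sort: it returns the n best items in descending score order.
    return heapq.nlargest(second_beam_size, next_hyps.items(),
                          key=lambda kv: kv[1])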