ref: 35ee397e060283d30c098ae5e17836316bbec08b
dir: /dnn/torch/lpcnet/utils/sample.py/
import torch
def sample_excitation(probs, pitch_corr):
    """Draw one index from ``probs`` after pitch-dependent sharpening and tail pruning.

    Steps:
      1. Raise ``probs`` to the power ``1 + max(0, 1.5 * pitch_corr - 0.5)``
         and renormalize. The exponent is 1 (no change) for
         ``pitch_corr <= 1/3`` and grows linearly above that, which lowers the
         sampling temperature.
      2. Subtract 0.002 from every entry, clamp at zero and renormalize, so
         that low-probability tails are never sampled. If this removes all of
         the mass, the un-pruned distribution from step 1 is used instead.
      3. Squeeze singleton dimensions and draw one sample with
         ``torch.multinomial``.

    Args:
        probs: non-negative tensor of (unnormalized) probabilities. Normalization
            is over ALL elements, so this presumably holds a single distribution
            (possibly with singleton dims) -- NOTE(review): confirm with callers.
        pitch_corr: scalar (Python number or one-element tensor); compared with
            0 via the builtin ``max``, so a multi-element tensor is not supported.

    Returns:
        The ``torch.multinomial`` result: a LongTensor of sampled indices, with
        shape ``(1,)`` when the squeezed input is 1-D.
    """
    def norm(x):
        # The epsilon avoids 0/0 when x holds no probability mass.
        return x / (x.sum() + 1e-18)

    # lowering the temperature
    probs = norm(probs ** (1 + max(0, 1.5 * pitch_corr - 0.5)))

    # cut-off tails; clamp keeps probs' device and dtype (a CPU FloatTensor
    # operand here would fail for CUDA inputs)
    pruned = norm(torch.clamp(probs - 0.002, min=0))
    # If every entry fell below the cut-off (e.g. a flat distribution over
    # >= 500 entries), multinomial would reject the all-zero tensor; fall back
    # to the tempered distribution. torch.where avoids a device->host sync.
    probs = torch.where(pruned.sum() > 0, pruned, probs)

    # sample
    exc = torch.multinomial(probs.squeeze(), 1)
    return exc