ref: ca2d0dc4c1d93490bb3129467003c8228928b87f
dir: /dnn/torch/lpcnet/utils/layers/subconditioner.py/
""" /* Copyright (c) 2023 Amazon Written by Jan Buethe */ /* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ """ from re import sub import torch from torch import nn def get_subconditioner( method, number_of_subsamples, pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs): subconditioner_dict = { 'additive' : AdditiveSubconditioner, 'concatenative' : ConcatenativeSubconditioner, 'modulative' : ModulativeSubconditioner } return subconditioner_dict[method](number_of_subsamples, pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs) class Subconditioner(nn.Module): def __init__(self): """ upsampling by subconditioning Upsamples a sequence of states conditioning on pcm signals and optionally a feature vector. """ super(Subconditioner, self).__init__() def forward(self, states, signals, features=None): raise Exception("Base class should not be called") def single_step(self, index, state, signals, features): raise Exception("Base class should not be called") def get_output_dim(self, index): raise Exception("Base class should not be called") class AdditiveSubconditioner(Subconditioner): def __init__(self, number_of_subsamples, pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs): """ subconditioning by addition """ super(AdditiveSubconditioner, self).__init__() self.number_of_subsamples = number_of_subsamples self.pcm_embedding_size = pcm_embedding_size self.state_size = state_size self.pcm_levels = pcm_levels self.number_of_signals = number_of_signals if self.pcm_embedding_size != self.state_size: raise ValueError('For additive subconditioning state and embedding ' + f'sizes must match but but got {self.state_size} and {self.pcm_embedding_size}') self.embeddings = [None] for i in range(1, self.number_of_subsamples): embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size) self.add_module('pcm_embedding_' + str(i), embedding) self.embeddings.append(embedding) def forward(self, states, signals): """ creates list of subconditioned states Parameters: ----------- states : torch.tensor states of shape (batch, seq_length // s, state_size) signals : torch.tensor signals of shape (batch, seq_length, number_of_signals) Returns: -------- c_states : list of torch.tensor list of s subconditioned states """ s = self.number_of_subsamples c_states = [states] new_states = states for i in range(1, self.number_of_subsamples): embed = self.embeddings[i](signals[:, i::s]) # reduce signal dimension embed = torch.sum(embed, dim=2) new_states = new_states + embed c_states.append(new_states) return c_states def single_step(self, index, state, signals): """ carry out single step for inference Parameters: ----------- index : int position in subconditioning batch state : torch.tensor state to sub-condition signals : torch.tensor signals for subconditioning, all but the last dimensions must match those of state Returns: c_state : torch.tensor subconditioned state """ if index == 0: c_state = state else: embed_signals = self.embeddings[index](signals) c = torch.sum(embed_signals, dim=-2) c_state = state + c return c_state def get_output_dim(self, index): return self.state_size def get_average_flops_per_step(self): s = self.number_of_subsamples flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size return flops class ConcatenativeSubconditioner(Subconditioner): def __init__(self, number_of_subsamples, pcm_embedding_size, state_size, pcm_levels, number_of_signals, recurrent=True, **kwargs): """ subconditioning by concatenation """ super(ConcatenativeSubconditioner, self).__init__() self.number_of_subsamples = number_of_subsamples self.pcm_embedding_size = pcm_embedding_size self.state_size = state_size self.pcm_levels = pcm_levels self.number_of_signals = number_of_signals self.recurrent = recurrent self.embeddings = [] start_index = 0 if self.recurrent: start_index = 1 self.embeddings.append(None) for i in range(start_index, self.number_of_subsamples): embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size) self.add_module('pcm_embedding_' + str(i), embedding) self.embeddings.append(embedding) def forward(self, states, signals): """ creates list of subconditioned states Parameters: ----------- states : torch.tensor states of shape (batch, seq_length // s, state_size) signals : torch.tensor signals of shape (batch, seq_length, number_of_signals) Returns: -------- c_states : list of torch.tensor list of s subconditioned states """ s = self.number_of_subsamples if self.recurrent: c_states = [states] start = 1 else: c_states = [] start = 0 new_states = states for i in range(start, self.number_of_subsamples): embed = self.embeddings[i](signals[:, i::s]) # reduce signal dimension embed = torch.flatten(embed, -2) if self.recurrent: new_states = torch.cat((new_states, embed), dim=-1) else: new_states = torch.cat((states, embed), dim=-1) c_states.append(new_states) return c_states def single_step(self, index, state, signals): """ carry out single step for inference Parameters: ----------- index : int position in subconditioning batch state : torch.tensor state to sub-condition signals : torch.tensor signals for subconditioning, all but the last dimensions must match those of state Returns: c_state : torch.tensor subconditioned state """ if index == 0 and self.recurrent: c_state = state else: embed_signals = self.embeddings[index](signals) c = torch.flatten(embed_signals, -2) if not self.recurrent and index > 0: # overwrite previous conditioning vector c_state = torch.cat((state[...,:self.state_size], c), dim=-1) else: c_state = torch.cat((state, c), dim=-1) return c_state return c_state def get_average_flops_per_step(self): return 0 def get_output_dim(self, index): if self.recurrent: return self.state_size + index * self.pcm_embedding_size * self.number_of_signals else: return self.state_size + self.pcm_embedding_size * self.number_of_signals class ModulativeSubconditioner(Subconditioner): def __init__(self, number_of_subsamples, pcm_embedding_size, state_size, pcm_levels, number_of_signals, state_recurrent=False, **kwargs): """ subconditioning by modulation """ super(ModulativeSubconditioner, self).__init__() self.number_of_subsamples = number_of_subsamples self.pcm_embedding_size = pcm_embedding_size self.state_size = state_size self.pcm_levels = pcm_levels self.number_of_signals = number_of_signals self.state_recurrent = state_recurrent self.hidden_size = self.pcm_embedding_size * self.number_of_signals if self.state_recurrent: self.hidden_size += self.pcm_embedding_size self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size) self.embeddings = [None] self.alphas = [None] self.betas = [None] for i in range(1, self.number_of_subsamples): embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size) self.add_module('pcm_embedding_' + str(i), embedding) self.embeddings.append(embedding) self.alphas.append(nn.Linear(self.hidden_size, self.state_size)) self.add_module('alpha_dense_' + str(i), self.alphas[-1]) self.betas.append(nn.Linear(self.hidden_size, self.state_size)) self.add_module('beta_dense_' + str(i), self.betas[-1]) def forward(self, states, signals): """ creates list of subconditioned states Parameters: ----------- states : torch.tensor states of shape (batch, seq_length // s, state_size) signals : torch.tensor signals of shape (batch, seq_length, number_of_signals) Returns: -------- c_states : list of torch.tensor list of s subconditioned states """ s = self.number_of_subsamples c_states = [states] new_states = states for i in range(1, self.number_of_subsamples): embed = self.embeddings[i](signals[:, i::s]) # reduce signal dimension embed = torch.flatten(embed, -2) if self.state_recurrent: comp_states = self.state_transform(new_states) embed = torch.cat((embed, comp_states), dim=-1) alpha = torch.tanh(self.alphas[i](embed)) beta = torch.tanh(self.betas[i](embed)) # new state obtained by modulating previous state new_states = torch.tanh((1 + alpha) * new_states + beta) c_states.append(new_states) return c_states def single_step(self, index, state, signals): """ carry out single step for inference Parameters: ----------- index : int position in subconditioning batch state : torch.tensor state to sub-condition signals : torch.tensor signals for subconditioning, all but the last dimensions must match those of state Returns: c_state : torch.tensor subconditioned state """ if index == 0: c_state = state else: embed_signals = self.embeddings[index](signals) c = torch.flatten(embed_signals, -2) if self.state_recurrent: r_state = self.state_transform(state) c = torch.cat((c, r_state), dim=-1) alpha = torch.tanh(self.alphas[index](c)) beta = torch.tanh(self.betas[index](c)) c_state = torch.tanh((1 + alpha) * state + beta) return c_state return c_state def get_output_dim(self, index): return self.state_size def get_average_flops_per_step(self): s = self.number_of_subsamples # estimate activation by 10 flops # c_state = torch.tanh((1 + alpha) * state + beta) flops = 13 * self.state_size # hidden size hidden_size = self.number_of_signals * self.pcm_embedding_size if self.state_recurrent: hidden_size += self.pcm_embedding_size # counting 2 * A * B flops for Linear(A, B) # alpha = torch.tanh(self.alphas[index](c)) # beta = torch.tanh(self.betas[index](c)) flops += 4 * hidden_size * self.state_size + 20 * self.state_size # r_state = self.state_transform(state) if self.state_recurrent: flops += 2 * self.state_size * self.pcm_embedding_size # average over steps flops *= (s - 1) / s return flops class ComparitiveSubconditioner(Subconditioner): def __init__(self, number_of_subsamples, pcm_embedding_size, state_size, pcm_levels, number_of_signals, error_index=-1, apply_gate=True, normalize=False): """ subconditioning by comparison """ super(ComparitiveSubconditioner, self).__init__() self.comparison_size = self.pcm_embedding_size self.error_position = error_index self.apply_gate = apply_gate self.normalize = normalize self.state_transform = nn.Linear(self.state_size, self.comparison_size) self.alpha_dense = nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size) self.beta_dense = nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size) if self.apply_gate: self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size) # embeddings and state transforms self.embeddings = [None] self.alpha_denses = [None] self.beta_denses = [None] self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)] self.add_module('state_transform_0', self.state_transforms[0]) for i in range(1, self.number_of_subsamples): embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size) self.add_module('pcm_embedding_' + str(i), embedding) self.embeddings.append(embedding) state_transform = nn.Linear(self.state_size, self.comparison_size) self.add_module('state_transform_' + str(i), state_transform) self.state_transforms.append(state_transform) self.alpha_denses.append(nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size)) self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1]) self.beta_denses.append(nn.Linear(self.number_of_signales * self.pcm_embedding_size, self.state_size)) self.add_module('beta_dense_' + str(i), self.beta_denses[-1]) def forward(self, states, signals): s = self.number_of_subsamples c_states = [states] new_states = states for i in range(1, self.number_of_subsamples): embed = self.embeddings[i](signals[:, i::s]) # reduce signal dimension embed = torch.flatten(embed, -2) comp_states = self.state_transforms[i](new_states) alpha = torch.tanh(self.alpha_dense(embed)) beta = torch.tanh(self.beta_dense(embed)) # new state obtained by modulating previous state new_states = torch.tanh((1 + alpha) * comp_states + beta) c_states.append(new_states) return c_states