ref: 35ee397e060283d30c098ae5e17836316bbec08b
dir: /dnn/torch/lpcnet/utils/layers/subconditioner.py/
import torch
from torch import nn

def get_subconditioner(method,
                       number_of_subsamples,
                       pcm_embedding_size,
                       state_size,
                       pcm_levels,
                       number_of_signals,
                       **kwargs):

    subconditioner_dict = {
        'additive'      : AdditiveSubconditioner,
        'concatenative' : ConcatenativeSubconditioner,
        'modulative'    : ModulativeSubconditioner
    }

    return subconditioner_dict[method](number_of_subsamples,
        pcm_embedding_size, state_size, pcm_levels, number_of_signals, **kwargs)
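
# Example usage (a hedged sketch; the sizes below are illustrative, not taken
# from the original file):
#
#   sub = get_subconditioner('modulative',
#                            number_of_subsamples=4,
#                            pcm_embedding_size=64,
#                            state_size=384,
#                            pcm_levels=256,
#                            number_of_signals=2)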

class Subconditioner(nn.Module):
    def __init__(self):
        """ upsampling by subconditioning

        Upsamples a sequence of states by conditioning on pcm signals and
        optionally a feature vector.
        """
        super(Subconditioner, self).__init__()

    def forward(self, states, signals, features=None):
        raise NotImplementedError("Base class should not be called")

    def single_step(self, index, state, signals, features):
        raise NotImplementedError("Base class should not be called")

    def get_output_dim(self, index):
        raise NotImplementedError("Base class should not be called")

class AdditiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 **kwargs):
        """ subconditioning by addition """

        super(AdditiveSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        if self.pcm_embedding_size != self.state_size:
            raise ValueError('For additive subconditioning, state and embedding '
                + f'sizes must match but got {self.state_size} and {self.pcm_embedding_size}')

        # index 0 is a placeholder: the first substate is passed through unchanged
        self.embeddings = [None]
        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples
        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.sum(embed, dim=2)
            new_states = new_states + embed
            c_states.append(new_states)

        return c_states
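
    # Shape sketch (illustrative values; for this class pcm_embedding_size must
    # equal state_size): with s = 4,
    #
    #   states  = torch.zeros((8, 25, 64))                    # (B, T // 4, 64)
    #   signals = torch.zeros((8, 100, 2), dtype=torch.long)  # (B, T, 2)
    #   c_states = sub(states, signals)  # list of 4 tensors of shape (8, 25, 64)
    #
    # entry 0 of c_states is the unmodified input state.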

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch
        state : torch.tensor
            state to sub-condition
        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        --------
        c_state : torch.tensor
            subconditioned state
        """
        if index == 0:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.sum(embed_signals, dim=-2)
            c_state = state + c

        return c_state

    def get_output_dim(self, index):
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples
        flops = (s - 1) / s * self.number_of_signals * self.pcm_embedding_size
        return flops
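
# Inference sketch (assumed usage pattern, not part of the original file): during
# autoregressive synthesis, single_step would be called once per subsample
# position to refresh the recurrent state with the newest signal values, e.g.
#
#   for i in range(sub.number_of_subsamples):
#       h = sub.single_step(i, h, last_signals)  # last_signals: (..., number_of_signals)
#       h = step_recurrent_unit(h)               # hypothetical sample-rate update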

class ConcatenativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 recurrent=True,
                 **kwargs):
        """ subconditioning by concatenation """

        super(ConcatenativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.recurrent = recurrent

        self.embeddings = []
        start_index = 0
        if self.recurrent:
            # in recurrent mode the first substate is passed through unchanged
            start_index = 1
            self.embeddings.append(None)

        for i in range(start_index, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples
        if self.recurrent:
            c_states = [states]
            start = 1
        else:
            c_states = []
            start = 0

        new_states = states
        for i in range(start, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)
            if self.recurrent:
                new_states = torch.cat((new_states, embed), dim=-1)
            else:
                new_states = torch.cat((states, embed), dim=-1)
            c_states.append(new_states)

        return c_states

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch
        state : torch.tensor
            state to sub-condition
        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        --------
        c_state : torch.tensor
            subconditioned state
        """
        if index == 0 and self.recurrent:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if not self.recurrent and index > 0:
                # overwrite previous conditioning vector
                c_state = torch.cat((state[..., :self.state_size], c), dim=-1)
            else:
                c_state = torch.cat((state, c), dim=-1)

        return c_state

    def get_average_flops_per_step(self):
        # concatenation itself costs no flops
        return 0

    def get_output_dim(self, index):
        if self.recurrent:
            return self.state_size + index * self.pcm_embedding_size * self.number_of_signals
        else:
            return self.state_size + self.pcm_embedding_size * self.number_of_signals
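
# Worked example (hypothetical sizes): with state_size=384, pcm_embedding_size=64,
# number_of_signals=2 and number_of_subsamples=4, get_output_dim gives
#
#   recurrent=True : 384, 512, 640, 768 for index = 0..3 (embeddings accumulate)
#   recurrent=False: 512 for every index (each step overwrites the previous
#                    conditioning vector instead of appending to it)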

class ModulativeSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 state_recurrent=False,
                 **kwargs):
        """ subconditioning by modulation """

        super(ModulativeSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals
        self.state_recurrent = state_recurrent

        self.hidden_size = self.pcm_embedding_size * self.number_of_signals

        if self.state_recurrent:
            self.hidden_size += self.pcm_embedding_size
            self.state_transform = nn.Linear(self.state_size, self.pcm_embedding_size)

        # index 0 is a placeholder: the first substate is passed through unchanged
        self.embeddings = [None]
        self.alphas = [None]
        self.betas = [None]

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            self.alphas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alphas[-1])

            self.betas.append(nn.Linear(self.hidden_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.betas[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Parameters:
        -----------
        states : torch.tensor
            states of shape (batch, seq_length // s, state_size)
        signals : torch.tensor
            signals of shape (batch, seq_length, number_of_signals)

        Returns:
        --------
        c_states : list of torch.tensor
            list of s subconditioned states
        """
        s = self.number_of_subsamples
        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)
            if self.state_recurrent:
                comp_states = self.state_transform(new_states)
                embed = torch.cat((embed, comp_states), dim=-1)

            alpha = torch.tanh(self.alphas[i](embed))
            beta = torch.tanh(self.betas[i](embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * new_states + beta)
            c_states.append(new_states)

        return c_states
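
    # The update above is an affine, FiLM-like modulation of the running state,
    #
    #   h_i = tanh((1 + alpha_i) * h_{i-1} + beta_i),
    #
    # where alpha_i, beta_i in (-1, 1) are predicted from the newly available
    # signal samples, so the per-channel gain (1 + alpha_i) stays within (0, 2).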

    def single_step(self, index, state, signals):
        """ carry out single step for inference

        Parameters:
        -----------
        index : int
            position in subconditioning batch
        state : torch.tensor
            state to sub-condition
        signals : torch.tensor
            signals for subconditioning, all but the last dimensions
            must match those of state

        Returns:
        --------
        c_state : torch.tensor
            subconditioned state
        """
        if index == 0:
            c_state = state
        else:
            embed_signals = self.embeddings[index](signals)
            c = torch.flatten(embed_signals, -2)
            if self.state_recurrent:
                r_state = self.state_transform(state)
                c = torch.cat((c, r_state), dim=-1)

            alpha = torch.tanh(self.alphas[index](c))
            beta = torch.tanh(self.betas[index](c))
            c_state = torch.tanh((1 + alpha) * state + beta)

        return c_state

    def get_output_dim(self, index):
        return self.state_size

    def get_average_flops_per_step(self):
        s = self.number_of_subsamples

        # c_state = torch.tanh((1 + alpha) * state + beta)
        # (estimating 10 flops per activation)
        flops = 13 * self.state_size

        # hidden size
        hidden_size = self.number_of_signals * self.pcm_embedding_size
        if self.state_recurrent:
            hidden_size += self.pcm_embedding_size

        # alpha = torch.tanh(self.alphas[index](c))
        # beta = torch.tanh(self.betas[index](c))
        # (counting 2 * A * B flops for Linear(A, B))
        flops += 4 * hidden_size * self.state_size + 20 * self.state_size

        # r_state = self.state_transform(state)
        if self.state_recurrent:
            flops += 2 * self.state_size * self.pcm_embedding_size

        # average over steps
        flops *= (s - 1) / s

        return flops
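
# Worked example (hypothetical sizes, state_recurrent=False): with s=4,
# state_size=384, number_of_signals=2 and pcm_embedding_size=64 (hidden_size=128):
#
#   flops = 13 * 384 + 4 * 128 * 384 + 20 * 384 = 209280
#   average over the s = 4 steps: 209280 * 3 / 4 = 156960 flops per step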

class ComparitiveSubconditioner(Subconditioner):
    def __init__(self,
                 number_of_subsamples,
                 pcm_embedding_size,
                 state_size,
                 pcm_levels,
                 number_of_signals,
                 error_index=-1,
                 apply_gate=True,
                 normalize=False):
        """ subconditioning by comparison """

        super(ComparitiveSubconditioner, self).__init__()

        self.number_of_subsamples = number_of_subsamples
        self.pcm_embedding_size = pcm_embedding_size
        self.state_size = state_size
        self.pcm_levels = pcm_levels
        self.number_of_signals = number_of_signals

        self.comparison_size = self.pcm_embedding_size
        self.error_position = error_index
        self.apply_gate = apply_gate
        self.normalize = normalize

        self.state_transform = nn.Linear(self.state_size, self.comparison_size)

        self.alpha_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)
        self.beta_dense = nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size)

        if self.apply_gate:
            self.gate_dense = nn.Linear(self.pcm_embedding_size, self.state_size)

        # embeddings and state transforms
        self.embeddings = [None]
        self.alpha_denses = [None]
        self.beta_denses = [None]
        self.state_transforms = [nn.Linear(self.state_size, self.comparison_size)]
        self.add_module('state_transform_0', self.state_transforms[0])

        for i in range(1, self.number_of_subsamples):
            embedding = nn.Embedding(self.pcm_levels, self.pcm_embedding_size)
            self.add_module('pcm_embedding_' + str(i), embedding)
            self.embeddings.append(embedding)

            state_transform = nn.Linear(self.state_size, self.comparison_size)
            self.add_module('state_transform_' + str(i), state_transform)
            self.state_transforms.append(state_transform)

            self.alpha_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('alpha_dense_' + str(i), self.alpha_denses[-1])

            self.beta_denses.append(nn.Linear(self.number_of_signals * self.pcm_embedding_size, self.state_size))
            self.add_module('beta_dense_' + str(i), self.beta_denses[-1])

    def forward(self, states, signals):
        """ creates list of subconditioned states

        Shapes are as in ModulativeSubconditioner.forward.
        """
        s = self.number_of_subsamples
        c_states = [states]
        new_states = states
        for i in range(1, self.number_of_subsamples):
            embed = self.embeddings[i](signals[:, i::s])
            # reduce signal dimension
            embed = torch.flatten(embed, -2)
            comp_states = self.state_transforms[i](new_states)

            alpha = torch.tanh(self.alpha_dense(embed))
            beta = torch.tanh(self.beta_dense(embed))

            # new state obtained by modulating previous state
            new_states = torch.tanh((1 + alpha) * comp_states + beta)
            c_states.append(new_states)

        return c_states
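
    # Note: the modulation above mirrors ModulativeSubconditioner,
    #
    #   h_i = tanh((1 + alpha_i) * t_i(h_{i-1}) + beta_i),
    #
    # with t_i the per-step state transform. Since alpha_i and beta_i have
    # state_size channels while t_i outputs comparison_size channels, the
    # elementwise product broadcasts only if comparison_size == state_size,
    # i.e. pcm_embedding_size == state_size.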