ref: 0fed741a87ccc061eff382d306fadd71acdfc57d
dir: /dnn/torch/rdovae/rdovae/rdovae.py/
""" /* Copyright (c) 2022 Amazon Written by Jan Buethe */ /* Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ """ """ Pytorch implementations of rate distortion optimized variational autoencoder """ import math as m import torch from torch import nn import torch.nn.functional as F import sys import os source_dir = os.path.split(os.path.abspath(__file__))[0] sys.path.append(os.path.join(source_dir, "../../lpcnet/")) from utils.sparsification import GRUSparsifier from torch.nn.utils import weight_norm # Quantization and rate related utily functions def soft_pvq(x, k): """ soft pyramid vector quantizer """ # L2 normalization x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True)) with torch.no_grad(): # quantization loop, no need to track gradients here x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True) # set initial scaling factor to k scale_factor = k x_scaled = scale_factor * x_norm1 x_quant = torch.round(x_scaled) # we aim for ||x_quant||_L1 = k for _ in range(10): # remove signs and calculate L1 norm abs_x_quant = torch.abs(x_quant) abs_x_scaled = torch.abs(x_scaled) l1_x_quant = torch.sum(abs_x_quant, axis=-1) # increase, where target is too small and decrease, where target is too large plus = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values factor = torch.where(l1_x_quant > k, minus, plus) factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor) scale_factor = scale_factor * factor.unsqueeze(-1) # update x x_scaled = scale_factor * x_norm1 x_quant = torch.round(x_quant) # L2 normalization of quantized x x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True)) quantization_error = x_quant_norm2 - x_norm2 return x_norm2 + quantization_error.detach() def cache_parameters(func): cache = dict() def cached_func(*args): if args in cache: return cache[args] else: cache[args] = func(*args) return cache[args] return cached_func @cache_parameters def pvq_codebook_size(n, k): if k == 0: return 1 if n == 0: return 0 return pvq_codebook_size(n - 1, k) + pvq_codebook_size(n, k - 1) + pvq_codebook_size(n - 1, k - 1) def soft_rate_estimate(z, r, reduce=True): """ rate approximation with dependent theta Eq. 
(7)""" rate = torch.sum( - torch.log2((1 - r)/(1 + r) * r ** torch.abs(z) + 1e-6), dim=-1 ) if reduce: rate = torch.mean(rate) return rate def hard_rate_estimate(z, r, theta, reduce=True): """ hard rate approximation """ z_q = torch.round(z) p0 = 1 - r ** (0.5 + 0.5 * theta) alpha = torch.relu(1 - torch.abs(z_q)) ** 2 rate = - torch.sum( (alpha * torch.log2(p0 * r ** torch.abs(z_q) + 1e-6) + (1 - alpha) * torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (torch.abs(z_q) - 1) + 1e-6)), dim=-1 ) if reduce: rate = torch.mean(rate) return rate def soft_dead_zone(x, dead_zone): """ approximates application of a dead zone to x """ d = dead_zone * 0.05 return x - d * torch.tanh(x / (0.1 + d)) def hard_quantize(x): """ round with copy gradient trick """ return x + (torch.round(x) - x).detach() def noise_quantize(x): """ simulates quantization with addition of random uniform noise """ return x + (torch.rand_like(x) - 0.5) # loss functions def distortion_loss(y_true, y_pred, rate_lambda=None): """ custom distortion loss for LPCNet features """ if y_true.size(-1) != 20: raise ValueError('distortion loss is designed to work with 20 features') ceps_error = y_pred[..., :18] - y_true[..., :18] pitch_error = 2*(y_pred[..., 18:19] - y_true[..., 18:19]) corr_error = y_pred[..., 19:] - y_true[..., 19:] pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2 loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1) if type(rate_lambda) != type(None): loss = loss / torch.sqrt(rate_lambda) loss = torch.mean(loss) return loss # sampling functions import random def random_split(start, stop, num_splits=3, min_len=3): get_min_len = lambda x : min([x[i+1] - x[i] for i in range(len(x) - 1)]) candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop] while get_min_len(candidate) < min_len: candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop] return candidate # weight initialization and clipping def init_weights(module): if isinstance(module, nn.GRU): for p in module.named_parameters(): if p[0].startswith('weight_hh_'): nn.init.orthogonal_(p[1]) def weight_clip_factory(max_value): """ weight clipping function concerning sum of abs values of adjecent weights """ def clip_weight_(w): stop = w.size(1) # omit last column if stop is odd if stop % 2: stop -= 1 max_values = max_value * torch.ones_like(w[:, :stop]) factor = max_value / torch.maximum(max_values, torch.repeat_interleave( torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2]), 2, 1)) with torch.no_grad(): w[:, :stop] *= factor def clip_weights(module): if isinstance(module, nn.GRU) or isinstance(module, nn.Linear): for name, w in module.named_parameters(): if name.startswith('weight'): clip_weight_(w) return clip_weights def n(x): return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.) 
# RDOVAE module and submodules

sparsify_start     = 12000
sparsify_stop      = 24000
sparsify_interval  = 100
sparsify_exponent  = 3
#sparsify_start = 0
#sparsify_stop = 0

sparse_params1 = {
#    'W_hr' : (1.0, [8, 4], True),
#    'W_hz' : (1.0, [8, 4], True),
#    'W_hn' : (1.0, [8, 4], True),
    'W_ir' : (0.6, [8, 4], False),
    'W_iz' : (0.4, [8, 4], False),
    'W_in' : (0.8, [8, 4], False)
}

sparse_params2 = {
#    'W_hr' : (1.0, [8, 4], True),
#    'W_hz' : (1.0, [8, 4], True),
#    'W_hn' : (1.0, [8, 4], True),
    'W_ir' : (0.3, [8, 4], False),
    'W_iz' : (0.2, [8, 4], False),
    'W_in' : (0.4, [8, 4], False)
}

class MyConv(nn.Module):
    def __init__(self, input_dim, output_dim, dilation=1):
        super(MyConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dilation = dilation
        self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)

    def forward(self, x, state=None):
        device = x.device
        conv_in = torch.cat([torch.zeros_like(x[:, 0:self.dilation, :], device=device), x], -2).permute(0, 2, 1)
        return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)

class GLU(nn.Module):
    def __init__(self, feat_size):
        super(GLU, self).__init__()

        torch.manual_seed(5)

        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
                    or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                nn.init.orthogonal_(m.weight.data)

    def forward(self, x):
        out = x * torch.sigmoid(self.gate(x))
        return out

class CoreEncoder(nn.Module):
    STATE_HIDDEN = 128
    FRAMES_PER_STEP = 2
    CONV_KERNEL_SIZE = 4

    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
        """ core encoder for RDOVAE

            Computes latents, initial states, and rate estimates from features and lambda parameter

        """
        super(CoreEncoder, self).__init__()

        # hyper parameters
        self.feature_dim = feature_dim
        self.output_dim = output_dim
        self.cond_size = cond_size
        self.cond_size2 = cond_size2
        self.state_size = state_size

        # derived parameters
        self.input_dim = self.FRAMES_PER_STEP * self.feature_dim

        # layers
        self.dense_1 = nn.Linear(self.input_dim, 64)
        self.gru1 = nn.GRU(64, 64, batch_first=True)
        self.conv1 = MyConv(128, 96)
        self.gru2 = nn.GRU(224, 64, batch_first=True)
        self.conv2 = MyConv(288, 96, dilation=2)
        self.gru3 = nn.GRU(384, 64, batch_first=True)
        self.conv3 = MyConv(448, 96, dilation=2)
        self.gru4 = nn.GRU(544, 64, batch_first=True)
        self.conv4 = MyConv(608, 96, dilation=2)
        self.gru5 = nn.GRU(704, 64, batch_first=True)
        self.conv5 = MyConv(768, 96, dilation=2)

        self.z_dense = nn.Linear(864, self.output_dim)

        self.state_dense_1 = nn.Linear(864, self.STATE_HIDDEN)
        self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)

        nb_params = sum(p.numel() for p in self.parameters())
        print(f"encoder: {nb_params} weights")

        # initialize weights
        self.apply(init_weights)

    def forward(self, features):

        # reshape features
        x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2)))

        batch = x.size(0)
        device = x.device

        # run encoding layer stack
        x = n(torch.tanh(self.dense_1(x)))
        x = torch.cat([x, n(self.gru1(x)[0])], -1)
        x = torch.cat([x, n(self.conv1(x))], -1)
        x = torch.cat([x, n(self.gru2(x)[0])], -1)
        x = torch.cat([x, n(self.conv2(x))], -1)
        x = torch.cat([x, n(self.gru3(x)[0])], -1)
        x = torch.cat([x, n(self.conv3(x))], -1)
        x = torch.cat([x, n(self.gru4(x)[0])], -1)
        x = torch.cat([x, n(self.conv4(x))], -1)
        x = torch.cat([x, n(self.gru5(x)[0])], -1)
        x = torch.cat([x, n(self.conv5(x))], -1)

        z = self.z_dense(x)

        # init state for decoder
        states = torch.tanh(self.state_dense_1(x))
        states = self.state_dense_2(states)

        return z, states

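# Minimal usage sketch for CoreEncoder (all sizes are assumptions for
# illustration only): the number of input feature frames must be divisible by
# FRAMES_PER_STEP, since forward() folds FRAMES_PER_STEP consecutive frames
# into a single encoder step.
def _example_core_encoder():
    encoder = CoreEncoder(feature_dim=20, output_dim=80, cond_size=256, cond_size2=256)
    features = torch.randn(1, 16, 20)   # (batch, frames, feature_dim), 16 % FRAMES_PER_STEP == 0
    z, states = encoder(features)       # z: (1, 8, 80), states: (1, 8, state_size)
    return z, states
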
class CoreDecoder(nn.Module):

    FRAMES_PER_STEP = 4

    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
        """ core decoder for RDOVAE

            Computes features from latents, initial state, and quantization index

        """
        super(CoreDecoder, self).__init__()

        # hyper parameters
        self.input_dim  = input_dim
        self.output_dim = output_dim
        self.cond_size  = cond_size
        self.cond_size2 = cond_size2
        self.state_size = state_size

        self.input_size = self.input_dim

        # layers
        self.dense_1 = nn.Linear(self.input_size, 96)
        self.gru1 = nn.GRU(96, 96, batch_first=True)
        self.conv1 = MyConv(192, 32)
        self.gru2 = nn.GRU(224, 96, batch_first=True)
        self.conv2 = MyConv(320, 32)
        self.gru3 = nn.GRU(352, 96, batch_first=True)
        self.conv3 = MyConv(448, 32)
        self.gru4 = nn.GRU(480, 96, batch_first=True)
        self.conv4 = MyConv(576, 32)
        self.gru5 = nn.GRU(608, 96, batch_first=True)
        self.conv5 = MyConv(704, 32)

        self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)

        self.glu1 = GLU(96)
        self.glu2 = GLU(96)
        self.glu3 = GLU(96)
        self.glu4 = GLU(96)
        self.glu5 = GLU(96)

        self.hidden_init = nn.Linear(self.state_size, 128)
        self.gru_init = nn.Linear(128, 480)

        nb_params = sum(p.numel() for p in self.parameters())
        print(f"decoder: {nb_params} weights")

        # initialize weights
        self.apply(init_weights)

        self.sparsifier = []
        self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
        self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))

    def sparsify(self):
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def forward(self, z, initial_state):

        hidden = torch.tanh(self.hidden_init(initial_state))
        gru_state = torch.tanh(self.gru_init(hidden).permute(1, 0, 2))
        h1_state = gru_state[:, :, :96].contiguous()
        h2_state = gru_state[:, :, 96:192].contiguous()
        h3_state = gru_state[:, :, 192:288].contiguous()
        h4_state = gru_state[:, :, 288:384].contiguous()
        h5_state = gru_state[:, :, 384:].contiguous()

        # run decoding layer stack
        x = n(torch.tanh(self.dense_1(z)))

        x = torch.cat([x, n(self.glu1(n(self.gru1(x, h1_state)[0])))], -1)
        x = torch.cat([x, n(self.conv1(x))], -1)
        x = torch.cat([x, n(self.glu2(n(self.gru2(x, h2_state)[0])))], -1)
        x = torch.cat([x, n(self.conv2(x))], -1)
        x = torch.cat([x, n(self.glu3(n(self.gru3(x, h3_state)[0])))], -1)
        x = torch.cat([x, n(self.conv3(x))], -1)
        x = torch.cat([x, n(self.glu4(n(self.gru4(x, h4_state)[0])))], -1)
        x = torch.cat([x, n(self.conv4(x))], -1)
        x = torch.cat([x, n(self.glu5(n(self.gru5(x, h5_state)[0])))], -1)
        x = torch.cat([x, n(self.conv5(x))], -1)

        # output layer and reshaping
        x10 = self.output(x)
        features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))

        return features

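# Usage note (sketch only; the real training loop lives in the training
# script): the decoder's GRU sparsifiers are stepped once per parameter
# update, so training code is expected to call sparsify() after each
# optimizer step.
def _example_training_step(model, optimizer, loss):
    """ illustrative training step assuming `loss` was computed from a forward pass """
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.sparsify()    # advances the GRUSparsifier schedules attached to the decoder GRUs
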
class StatisticalModel(nn.Module):
    def __init__(self, quant_levels, latent_dim, state_dim):
        """ Statistical model for latent space

            Computes scaling, deadzone, r, and theta

        """
        super(StatisticalModel, self).__init__()

        # copy parameters
        self.latent_dim     = latent_dim
        self.state_dim      = state_dim
        self.total_dim      = latent_dim + state_dim
        self.quant_levels   = quant_levels
        self.embedding_dim  = 6 * self.total_dim

        # quantization embedding
        self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)

        # initialize embedding to 0
        with torch.no_grad():
            self.quant_embedding.weight[:] = 0

    def forward(self, quant_ids):
        """ takes quant_ids and returns statistical model parameters """

        x = self.quant_embedding(quant_ids)

        # CAVE: theta_soft is not used anymore. Kick it out?
        quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim])
        dead_zone   = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim])
        theta_soft  = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim])
        r_soft      = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim])
        theta_hard  = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim])
        r_hard      = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim])

        return {
            'quant_embedding' : x,
            'quant_scale'     : quant_scale,
            'dead_zone'       : dead_zone,
            'r_hard'          : r_hard,
            'theta_hard'      : theta_hard,
            'r_soft'          : r_soft,
            'theta_soft'      : theta_soft
        }

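# Sketch (sizes are assumptions for illustration): the statistical model maps
# one integer quantizer id per encoder step to per-dimension parameters for
# both the latents and the initial decoder states.
def _example_statistical_model():
    stat_model = StatisticalModel(quant_levels=16, latent_dim=80, state_dim=24)
    q_ids = torch.randint(0, 16, (1, 10))           # (batch, num_encoder_steps)
    stats = stat_model(q_ids)                       # each entry: (1, 10, latent_dim + state_dim)
    latent_scale = stats['quant_scale'][..., :80]   # scaling for the latent dimensions
    state_scale = stats['quant_scale'][..., 80:]    # scaling for the state dimensions
    return latent_scale, state_scale
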
class RDOVAE(nn.Module):
    def __init__(self,
                 feature_dim,
                 latent_dim,
                 quant_levels,
                 cond_size,
                 cond_size2,
                 state_dim=24,
                 split_mode='split',
                 clip_weights=False,
                 pvq_num_pulses=82,
                 state_dropout_rate=0):

        super(RDOVAE, self).__init__()

        self.feature_dim        = feature_dim
        self.latent_dim         = latent_dim
        self.quant_levels       = quant_levels
        self.cond_size          = cond_size
        self.cond_size2         = cond_size2
        self.split_mode         = split_mode
        self.state_dim          = state_dim
        self.pvq_num_pulses     = pvq_num_pulses
        self.state_dropout_rate = state_dropout_rate

        # submodules: encoder and decoder share the statistical model
        self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))

        self.enc_stride = CoreEncoder.FRAMES_PER_STEP
        self.dec_stride = CoreDecoder.FRAMES_PER_STEP

        if clip_weights:
            self.weight_clip_fn = weight_clip_factory(0.496)
        else:
            self.weight_clip_fn = None

        if self.dec_stride % self.enc_stride != 0:
            raise ValueError("get_decoder_chunks_generic: encoder stride does not divide decoder stride")

    def clip_weights(self):
        if self.weight_clip_fn is not None:
            self.apply(self.weight_clip_fn)

    def sparsify(self):
        #self.core_encoder.module.sparsify()
        self.core_decoder.module.sparsify()

    def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset=4):

        enc_stride = self.enc_stride
        dec_stride = self.dec_stride

        stride = dec_stride // enc_stride

        chunks = []

        for offset in range(stride):
            # start is the smallest number congruent to offset mod stride that decodes to a valid range
            start = offset
            while enc_stride * (start + 1) - dec_stride < 0:
                start += stride

            # check if start is a valid index
            if start >= z_frames:
                raise ValueError("get_decoder_chunks_generic: range too small")

            # stop is the smallest number outside [0, z_frames] that's congruent to offset mod stride
            stop = z_frames - (z_frames % stride) + offset
            while stop < z_frames:
                stop += stride

            # calculate split points
            length = (stop - start)
            if mode == 'split':
                split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
            elif mode == 'random_split':
                split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)]
            else:
                raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}")

            for i in range(chunks_per_offset):
                # (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop)
                # encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j)
                # provided that i - j = 1 mod stride
                chunks.append({
                    'z_start'        : split_points[i],
                    'z_stop'         : split_points[i + 1] - stride + 1,
                    'z_stride'       : stride,
                    'features_start' : enc_stride * (split_points[i] + 1) - dec_stride,
                    'features_stop'  : enc_stride * (split_points[i + 1] - stride + 1)
                })

        return chunks

    def forward(self, features, q_id):

        # calculate statistical model from quantization ID
        statistical_model = self.statistical_model(q_id)

        # run encoder
        z, states = self.core_encoder(features)

        # scaling, dead-zone and quantization
        z = z * statistical_model['quant_scale'][:, :, :self.latent_dim]
        z = soft_dead_zone(z, statistical_model['dead_zone'][:, :, :self.latent_dim])

        # quantization
        z_q = hard_quantize(z) / statistical_model['quant_scale'][:, :, :self.latent_dim]
        z_n = noise_quantize(z) / statistical_model['quant_scale'][:, :, :self.latent_dim]
        #states_q = soft_pvq(states, self.pvq_num_pulses)
        states = states * statistical_model['quant_scale'][:, :, self.latent_dim:]
        states = soft_dead_zone(states, statistical_model['dead_zone'][:, :, self.latent_dim:])

        states_q = hard_quantize(states) / statistical_model['quant_scale'][:, :, self.latent_dim:]
        states_n = noise_quantize(states) / statistical_model['quant_scale'][:, :, self.latent_dim:]

        if self.state_dropout_rate > 0:
            drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
            mask = torch.ones_like(states_q)
            mask[drop] = 0
            states_q = states_q * mask

        # decoder
        chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode)

        outputs_hq = []
        outputs_sq = []
        for chunk in chunks:
            # decoder with hard quantized input
            z_dec_reverse = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
            dec_initial_state = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
            features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
            outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))

            # decoder with soft quantized input
            z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
            dec_initial_state = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
            features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
            outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))

        return {
            'outputs_hard_quant' : outputs_hq,
            'outputs_soft_quant' : outputs_sq,
            'z'                  : z,
            'states'             : states,
            'statistical_model'  : statistical_model
        }

    def encode(self, features):
        """ encoder with quantization and rate estimation """

        z, states = self.core_encoder(features)

        # quantization of initial states
        states = soft_pvq(states, self.pvq_num_pulses)
        state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses))

        return z, states, state_size

    def decode(self, z, initial_state):
        """ decoder (flips sequences by itself) """

        z_reverse = torch.flip(z, [1])
        features_reverse = self.core_decoder(z_reverse, initial_state)
        features = torch.flip(features_reverse, [1])

        return features

    def quantize(self, z, q_ids):
        """ quantization of latent vectors """

        stats = self.statistical_model(q_ids)

        zq = z * stats['quant_scale'][:self.latent_dim]
        zq = soft_dead_zone(zq, stats['dead_zone'][:self.latent_dim])
        zq = torch.round(zq)

        sizes = hard_rate_estimate(zq, stats['r_hard'][:, :, :self.latent_dim], stats['theta_hard'][:, :, :self.latent_dim], reduce=False)

        return zq, sizes

    def unquantize(self, zq, q_ids):
        """ re-scaling of latent vector """

        stats = self.statistical_model(q_ids)

        z = zq / stats['quant_scale'][:, :, :self.latent_dim]

        return z

    def freeze_model(self):

        # freeze all parameters
        for p in self.parameters():
            p.requires_grad = False

        for p in self.statistical_model.parameters():
            p.requires_grad = True

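# End-to-end sketch (all sizes are assumptions for illustration; the real
# values come from the training script): one forward pass through the model
# and a distortion term for the first hard-quantized chunk.
def _example_rdovae_forward():
    model = RDOVAE(feature_dim=20, latent_dim=80, quant_levels=16, cond_size=256, cond_size2=256)
    features = torch.randn(4, 96, 20)                          # (batch, frames, features), frames % dec_stride == 0
    q_id = torch.randint(0, 16, (4, 96 // model.enc_stride))   # one quantizer id per encoder step
    out = model(features, q_id)
    decoded, start, stop = out['outputs_hard_quant'][0]        # decoded chunk and its feature frame range
    return distortion_loss(features[:, start:stop, :], decoded)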