ref: 0fed741a87ccc061eff382d306fadd71acdfc57d
dir: /dnn/torch/rdovae/rdovae/rdovae.py/
"""
/* Copyright (c) 2022 Amazon
Written by Jan Buethe */
/*
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""
""" Pytorch implementations of rate distortion optimized variational autoencoder """
import math as m
import torch
from torch import nn
import torch.nn.functional as F
import sys
import os
source_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
from utils.sparsification import GRUSparsifier
from torch.nn.utils import weight_norm
# Quantization and rate related utility functions
def soft_pvq(x, k):
""" soft pyramid vector quantizer """
# L2 normalization
x_norm2 = x / (1e-15 + torch.norm(x, dim=-1, keepdim=True))
with torch.no_grad():
# quantization loop, no need to track gradients here
x_norm1 = x / torch.sum(torch.abs(x), dim=-1, keepdim=True)
# set initial scaling factor to k
scale_factor = k
x_scaled = scale_factor * x_norm1
x_quant = torch.round(x_scaled)
# we aim for ||x_quant||_L1 = k
for _ in range(10):
# remove signs and calculate L1 norm
abs_x_quant = torch.abs(x_quant)
abs_x_scaled = torch.abs(x_scaled)
l1_x_quant = torch.sum(abs_x_quant, axis=-1)
# increase the scale where the quantized L1 norm is too small, decrease it where it is too large
plus = 1.0001 * torch.min((abs_x_quant + 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
minus = 0.9999 * torch.max((abs_x_quant - 0.5) / (abs_x_scaled + 1e-15), dim=-1).values
factor = torch.where(l1_x_quant > k, minus, plus)
factor = torch.where(l1_x_quant == k, torch.ones_like(factor), factor)
scale_factor = scale_factor * factor.unsqueeze(-1)
# update x
x_scaled = scale_factor * x_norm1
x_quant = torch.round(x_scaled)
# L2 normalization of quantized x
x_quant_norm2 = x_quant / (1e-15 + torch.norm(x_quant, dim=-1, keepdim=True))
quantization_error = x_quant_norm2 - x_norm2
return x_norm2 + quantization_error.detach()
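# Note: the return value acts as a straight-through estimator: the forward pass yields the
# L2-normalized quantized point x_quant_norm2, while gradients flow through the unquantized
# x_norm2. A minimal REPL sketch (illustrative, assuming a (batch, dim) input):
#
#   >>> x = torch.randn(4, 24)
#   >>> y = soft_pvq(x, 82)
#   >>> y.shape                                    # same shape as the input
#   torch.Size([4, 24])
#   >>> torch.allclose(torch.norm(y, dim=-1), torch.ones(4), atol=1e-4)
#   True                                           # outputs lie (approximately) on the unit sphere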
def cache_parameters(func):
cache = dict()
def cached_func(*args):
if args in cache:
return cache[args]
else:
cache[args] = func(*args)
return cache[args]
return cached_func
@cache_parameters
def pvq_codebook_size(n, k):
if k == 0:
return 1
if n == 0:
return 0
return pvq_codebook_size(n - 1, k) + pvq_codebook_size(n, k - 1) + pvq_codebook_size(n - 1, k - 1)
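# Note: this is the standard PVQ codebook-size recursion
# N(n, k) = N(n-1, k) + N(n, k-1) + N(n-1, k-1), counting integer vectors of dimension n
# whose absolute values sum to exactly k, e.g.
#
#   >>> pvq_codebook_size(2, 1)    # (1, 0), (-1, 0), (0, 1), (0, -1)
#   4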
def soft_rate_estimate(z, r, reduce=True):
""" rate approximation with dependent theta Eq. (7)"""
rate = torch.sum(
- torch.log2((1 - r)/(1 + r) * r ** torch.abs(z) + 1e-6),
dim=-1
)
if reduce:
rate = torch.mean(rate)
return rate
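# Note: the soft rate corresponds to a two-sided geometric prior
# P(q) = (1 - r) / (1 + r) * r**|q| evaluated at the continuous latent z, which keeps the
# estimate differentiable in both z and r; the 1e-6 term only guards the log2 against zero.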
def hard_rate_estimate(z, r, theta, reduce=True):
""" hard rate approximation """
z_q = torch.round(z)
p0 = 1 - r ** (0.5 + 0.5 * theta)
alpha = torch.relu(1 - torch.abs(z_q)) ** 2
rate = - torch.sum(
(alpha * torch.log2(p0 * r ** torch.abs(z_q) + 1e-6)
+ (1 - alpha) * torch.log2(0.5 * (1 - p0) * (1 - r) * r ** (torch.abs(z_q) - 1) + 1e-6)),
dim=-1
)
if reduce:
rate = torch.mean(rate)
return rate
def soft_dead_zone(x, dead_zone):
""" approximates application of a dead zone to x """
d = dead_zone * 0.05
return x - d * torch.tanh(x / (0.1 + d))
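# Note: for |x| much larger than d the tanh saturates and the function approaches the classic
# dead-zone shrink x - d * sign(x), while near zero it stays smooth and keeps gradients alive.
# Illustrative check (dead_zone=10 gives d=0.5):
#
#   >>> soft_dead_zone(torch.tensor([4.0]), torch.tensor([10.0]))
#   tensor([3.5000])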
def hard_quantize(x):
""" round with copy gradient trick """
return x + (torch.round(x) - x).detach()
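# Note: hard_quantize is the usual straight-through estimator for rounding: the forward value
# is round(x), but since (round(x) - x) is detached, the gradient of the whole expression with
# respect to x is exactly 1.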
def noise_quantize(x):
""" simulates quantization with addition of random uniform noise """
return x + (torch.rand_like(x) - 0.5)
# loss functions
def distortion_loss(y_true, y_pred, rate_lambda=None):
""" custom distortion loss for LPCNet features """
if y_true.size(-1) != 20:
raise ValueError('distortion loss is designed to work with 20 features')
ceps_error = y_pred[..., :18] - y_true[..., :18]
pitch_error = 2*(y_pred[..., 18:19] - y_true[..., 18:19])
corr_error = y_pred[..., 19:] - y_true[..., 19:]
pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2
loss = torch.mean(ceps_error ** 2 + (10/18) * torch.abs(pitch_error) * pitch_weight + (1/18) * corr_error ** 2, dim=-1)
if rate_lambda is not None:
loss = loss / torch.sqrt(rate_lambda)
loss = torch.mean(loss)
return loss
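# Note: the 20 features correspond to the LPCNet feature layout, 18 cepstral coefficients
# followed by one pitch feature and one pitch-correlation feature; the correlation feature
# also drives pitch_weight, so unvoiced frames contribute little to the pitch term.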
# sampling functions
import random
def random_split(start, stop, num_splits=3, min_len=3):
get_min_len = lambda x : min([x[i+1] - x[i] for i in range(len(x) - 1)])
candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]
while get_min_len(candidate) < min_len:
candidate = [start] + sorted([random.randint(start, stop-1) for i in range(num_splits)]) + [stop]
return candidate
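# Note: random_split draws num_splits interior boundaries uniformly in [start, stop - 1] and
# resamples until every segment is at least min_len long; the result always includes start and
# stop, so e.g. random_split(0, 20) returns 5 sorted boundaries describing 4 segments. It is
# used further down to pick random decoder chunk boundaries.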
# weight initialization and clipping
def init_weights(module):
if isinstance(module, nn.GRU):
for p in module.named_parameters():
if p[0].startswith('weight_hh_'):
nn.init.orthogonal_(p[1])
def weight_clip_factory(max_value):
""" weight clipping function concerning sum of abs values of adjecent weights """
def clip_weight_(w):
stop = w.size(1)
# omit last column if stop is odd
if stop % 2:
stop -= 1
max_values = max_value * torch.ones_like(w[:, :stop])
factor = max_value / torch.maximum(max_values,
torch.repeat_interleave(
torch.abs(w[:, :stop:2]) + torch.abs(w[:, 1:stop:2]),
2,
1))
with torch.no_grad():
w[:, :stop] *= factor
def clip_weights(module):
if isinstance(module, nn.GRU) or isinstance(module, nn.Linear):
for name, w in module.named_parameters():
if name.startswith('weight'):
clip_weight_(w)
return clip_weights
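# Note: clipping acts on pairs of adjacent columns, rescaling both so that
# |w[:, 2i]| + |w[:, 2i+1]| <= max_value element-wise per row; presumably this keeps pairwise
# weight sums within the headroom of the quantized dot-product kernels used at inference time
# (hence the 0.496 value passed further down).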
def n(x):
return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
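# Note: n() adds uniform noise of width 1/127 and clamps to [-1, 1]; this appears to simulate
# 8-bit quantization of the activations during training, so the float model better matches a
# quantized inference path.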
# RDOVAE module and submodules
sparsify_start = 12000
sparsify_stop = 24000
sparsify_interval = 100
sparsify_exponent = 3
#sparsify_start = 0
#sparsify_stop = 0
sparse_params1 = {
# 'W_hr' : (1.0, [8, 4], True),
# 'W_hz' : (1.0, [8, 4], True),
# 'W_hn' : (1.0, [8, 4], True),
'W_ir' : (0.6, [8, 4], False),
'W_iz' : (0.4, [8, 4], False),
'W_in' : (0.8, [8, 4], False)
}
sparse_params2 = {
# 'W_hr' : (1.0, [8, 4], True),
# 'W_hz' : (1.0, [8, 4], True),
# 'W_hn' : (1.0, [8, 4], True),
'W_ir' : (0.3, [8, 4], False),
'W_iz' : (0.2, [8, 4], False),
'W_in' : (0.4, [8, 4], False)
}
class MyConv(nn.Module):
def __init__(self, input_dim, output_dim, dilation=1):
super(MyConv, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.dilation=dilation
self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)
def forward(self, x, state=None):
device = x.device
conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
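# Note: MyConv is a causal 1D convolution over time: the input is left-padded with `dilation`
# zero frames, so each output frame only sees the current and past frames. Shape sketch for
# (batch, time, channels) inputs:
#
#   >>> conv = MyConv(128, 96, dilation=2)
#   >>> conv(torch.randn(1, 50, 128)).shape
#   torch.Size([1, 50, 96])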
class GLU(nn.Module):
def __init__(self, feat_size):
super(GLU, self).__init__()
torch.manual_seed(5)
self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
nn.init.orthogonal_(m.weight.data)
def forward(self, x):
out = x * torch.sigmoid(self.gate(x))
return out
class CoreEncoder(nn.Module):
STATE_HIDDEN = 128
FRAMES_PER_STEP = 2
CONV_KERNEL_SIZE = 4
def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
""" core encoder for RDOVAE
Computes latents and initial decoder states from the input features
"""
super(CoreEncoder, self).__init__()
# hyper parameters
self.feature_dim = feature_dim
self.output_dim = output_dim
self.cond_size = cond_size
self.cond_size2 = cond_size2
self.state_size = state_size
# derived parameters
self.input_dim = self.FRAMES_PER_STEP * self.feature_dim
# layers
self.dense_1 = nn.Linear(self.input_dim, 64)
self.gru1 = nn.GRU(64, 64, batch_first=True)
self.conv1 = MyConv(128, 96)
self.gru2 = nn.GRU(224, 64, batch_first=True)
self.conv2 = MyConv(288, 96, dilation=2)
self.gru3 = nn.GRU(384, 64, batch_first=True)
self.conv3 = MyConv(448, 96, dilation=2)
self.gru4 = nn.GRU(544, 64, batch_first=True)
self.conv4 = MyConv(608, 96, dilation=2)
self.gru5 = nn.GRU(704, 64, batch_first=True)
self.conv5 = MyConv(768, 96, dilation=2)
self.z_dense = nn.Linear(864, self.output_dim)
self.state_dense_1 = nn.Linear(864, self.STATE_HIDDEN)
self.state_dense_2 = nn.Linear(self.STATE_HIDDEN, self.state_size)
nb_params = sum(p.numel() for p in self.parameters())
print(f"encoder: {nb_params} weights")
# initialize weights
self.apply(init_weights)
def forward(self, features):
# reshape features
x = torch.reshape(features, (features.size(0), features.size(1) // self.FRAMES_PER_STEP, self.FRAMES_PER_STEP * features.size(2)))
batch = x.size(0)
device = x.device
# run encoding layer stack
x = n(torch.tanh(self.dense_1(x)))
x = torch.cat([x, n(self.gru1(x)[0])], -1)
x = torch.cat([x, n(self.conv1(x))], -1)
x = torch.cat([x, n(self.gru2(x)[0])], -1)
x = torch.cat([x, n(self.conv2(x))], -1)
x = torch.cat([x, n(self.gru3(x)[0])], -1)
x = torch.cat([x, n(self.conv3(x))], -1)
x = torch.cat([x, n(self.gru4(x)[0])], -1)
x = torch.cat([x, n(self.conv4(x))], -1)
x = torch.cat([x, n(self.gru5(x)[0])], -1)
x = torch.cat([x, n(self.conv5(x))], -1)
z = self.z_dense(x)
# init state for decoder
states = torch.tanh(self.state_dense_1(x))
states = self.state_dense_2(states)
return z, states
class CoreDecoder(nn.Module):
FRAMES_PER_STEP = 4
def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
""" core decoder for RDOVAE
Computes features from latents and the initial decoder state
"""
super(CoreDecoder, self).__init__()
# hyper parameters
self.input_dim = input_dim
self.output_dim = output_dim
self.cond_size = cond_size
self.cond_size2 = cond_size2
self.state_size = state_size
self.input_size = self.input_dim
# layers
self.dense_1 = nn.Linear(self.input_size, 96)
self.gru1 = nn.GRU(96, 96, batch_first=True)
self.conv1 = MyConv(192, 32)
self.gru2 = nn.GRU(224, 96, batch_first=True)
self.conv2 = MyConv(320, 32)
self.gru3 = nn.GRU(352, 96, batch_first=True)
self.conv3 = MyConv(448, 32)
self.gru4 = nn.GRU(480, 96, batch_first=True)
self.conv4 = MyConv(576, 32)
self.gru5 = nn.GRU(608, 96, batch_first=True)
self.conv5 = MyConv(704, 32)
self.output = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
self.glu1 = GLU(96)
self.glu2 = GLU(96)
self.glu3 = GLU(96)
self.glu4 = GLU(96)
self.glu5 = GLU(96)
self.hidden_init = nn.Linear(self.state_size, 128)
self.gru_init = nn.Linear(128, 480)
nb_params = sum(p.numel() for p in self.parameters())
print(f"decoder: {nb_params} weights")
# initialize weights
self.apply(init_weights)
self.sparsifier = []
self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
def sparsify(self):
for sparsifier in self.sparsifier:
sparsifier.step()
def forward(self, z, initial_state):
hidden = torch.tanh(self.hidden_init(initial_state))
gru_state = torch.tanh(self.gru_init(hidden).permute(1, 0, 2))
h1_state = gru_state[:,:,:96].contiguous()
h2_state = gru_state[:,:,96:192].contiguous()
h3_state = gru_state[:,:,192:288].contiguous()
h4_state = gru_state[:,:,288:384].contiguous()
h5_state = gru_state[:,:,384:].contiguous()
# run decoding layer stack
x = n(torch.tanh(self.dense_1(z)))
x = torch.cat([x, n(self.glu1(n(self.gru1(x, h1_state)[0])))], -1)
x = torch.cat([x, n(self.conv1(x))], -1)
x = torch.cat([x, n(self.glu2(n(self.gru2(x, h2_state)[0])))], -1)
x = torch.cat([x, n(self.conv2(x))], -1)
x = torch.cat([x, n(self.glu3(n(self.gru3(x, h3_state)[0])))], -1)
x = torch.cat([x, n(self.conv3(x))], -1)
x = torch.cat([x, n(self.glu4(n(self.gru4(x, h4_state)[0])))], -1)
x = torch.cat([x, n(self.conv4(x))], -1)
x = torch.cat([x, n(self.glu5(n(self.gru5(x, h5_state)[0])))], -1)
x = torch.cat([x, n(self.conv5(x))], -1)
# output layer and reshaping
x10 = self.output(x)
features = torch.reshape(x10, (x10.size(0), x10.size(1) * self.FRAMES_PER_STEP, x10.size(2) // self.FRAMES_PER_STEP))
return features
class StatisticalModel(nn.Module):
def __init__(self, quant_levels, latent_dim, state_dim):
""" Statistical model for latent space
Computes scaling, deadzone, r, and theta
"""
super(StatisticalModel, self).__init__()
# copy parameters
self.latent_dim = latent_dim
self.state_dim = state_dim
self.total_dim = latent_dim + state_dim
self.quant_levels = quant_levels
self.embedding_dim = 6 * self.total_dim
# quantization embedding
self.quant_embedding = nn.Embedding(quant_levels, self.embedding_dim)
# initialize embedding to 0
with torch.no_grad():
self.quant_embedding.weight[:] = 0
def forward(self, quant_ids):
""" takes quant_ids and returns statistical model parameters"""
x = self.quant_embedding(quant_ids)
# CAVE: theta_soft is not used anymore. Kick it out?
quant_scale = F.softplus(x[..., 0 * self.total_dim : 1 * self.total_dim])
dead_zone = F.softplus(x[..., 1 * self.total_dim : 2 * self.total_dim])
theta_soft = torch.sigmoid(x[..., 2 * self.total_dim : 3 * self.total_dim])
r_soft = torch.sigmoid(x[..., 3 * self.total_dim : 4 * self.total_dim])
theta_hard = torch.sigmoid(x[..., 4 * self.total_dim : 5 * self.total_dim])
r_hard = torch.sigmoid(x[..., 5 * self.total_dim : 6 * self.total_dim])
return {
'quant_embedding' : x,
'quant_scale' : quant_scale,
'dead_zone' : dead_zone,
'r_hard' : r_hard,
'theta_hard' : theta_hard,
'r_soft' : r_soft,
'theta_soft' : theta_soft
}
class RDOVAE(nn.Module):
def __init__(self,
feature_dim,
latent_dim,
quant_levels,
cond_size,
cond_size2,
state_dim=24,
split_mode='split',
clip_weights=False,
pvq_num_pulses=82,
state_dropout_rate=0):
super(RDOVAE, self).__init__()
self.feature_dim = feature_dim
self.latent_dim = latent_dim
self.quant_levels = quant_levels
self.cond_size = cond_size
self.cond_size2 = cond_size2
self.split_mode = split_mode
self.state_dim = state_dim
self.pvq_num_pulses = pvq_num_pulses
self.state_dropout_rate = state_dropout_rate
# submodules; the encoder and decoder share the statistical model
self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
self.enc_stride = CoreEncoder.FRAMES_PER_STEP
self.dec_stride = CoreDecoder.FRAMES_PER_STEP
if clip_weights:
self.weight_clip_fn = weight_clip_factory(0.496)
else:
self.weight_clip_fn = None
if self.dec_stride % self.enc_stride != 0:
raise ValueError(f"get_decoder_chunks_generic: encoder stride does not divide decoder stride")
def clip_weights(self):
if self.weight_clip_fn is not None:
self.apply(self.weight_clip_fn)
def sparsify(self):
#self.core_encoder.module.sparsify()
self.core_decoder.module.sparsify()
def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
enc_stride = self.enc_stride
dec_stride = self.dec_stride
stride = dec_stride // enc_stride
chunks = []
for offset in range(stride):
# start is the smallest number congruent to offset mod stride that decodes to a valid feature range
start = offset
while enc_stride * (start + 1) - dec_stride < 0:
start += stride
# check if start is a valid index
if start >= z_frames:
raise ValueError("get_decoder_chunks_generic: range too small")
# stop is the smallest number >= z_frames that is congruent to offset mod stride
stop = z_frames - (z_frames % stride) + offset
while stop < z_frames:
stop += stride
# calculate split points
length = (stop - start)
if mode == 'split':
split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
elif mode == 'random_split':
split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)]
else:
raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}")
for i in range(chunks_per_offset):
# (enc_frame_start, enc_frame_stop, enc_frame_stride, stride, feature_frame_start, feature_frame_stop)
# encoder range(i, j, stride) maps to feature range(enc_stride * (i + 1) - dec_stride, enc_stride * j)
# provided that j - i == 1 mod stride
chunks.append({
'z_start' : split_points[i],
'z_stop' : split_points[i + 1] - stride + 1,
'z_stride' : stride,
'features_start' : enc_stride * (split_points[i] + 1) - dec_stride,
'features_stop' : enc_stride * (split_points[i + 1] - stride + 1)
})
return chunks
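# Note: each chunk maps a strided slice of latent frames to the feature frames it decodes.
# With enc_stride=2 and dec_stride=4 (stride=2), a hypothetical chunk with z_start=1 and
# z_stop=6 covers latent frames 1, 3, 5 and feature frames
# features_start = 2*(1+1) - 4 = 0 up to features_stop = 2*6 = 12.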
def forward(self, features, q_id):
# calculate statistical model from quantization ID
statistical_model = self.statistical_model(q_id)
# run encoder
z, states = self.core_encoder(features)
# scaling, dead-zone and quantization
z = z * statistical_model['quant_scale'][:,:,:self.latent_dim]
z = soft_dead_zone(z, statistical_model['dead_zone'][:,:,:self.latent_dim])
# quantization
z_q = hard_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
z_n = noise_quantize(z) / statistical_model['quant_scale'][:,:,:self.latent_dim]
#states_q = soft_pvq(states, self.pvq_num_pulses)
states = states * statistical_model['quant_scale'][:,:,self.latent_dim:]
states = soft_dead_zone(states, statistical_model['dead_zone'][:,:,self.latent_dim:])
states_q = hard_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
states_n = noise_quantize(states) / statistical_model['quant_scale'][:,:,self.latent_dim:]
if self.state_dropout_rate > 0:
drop = torch.rand(states_q.size(0)) < self.state_dropout_rate
mask = torch.ones_like(states_q)
mask[drop] = 0
states_q = states_q * mask
# decoder
chunks = self.get_decoder_chunks(z.size(1), mode=self.split_mode)
outputs_hq = []
outputs_sq = []
for chunk in chunks:
# decoder with hard quantized input
z_dec_reverse = torch.flip(z_q[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
dec_initial_state = states_q[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
outputs_hq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
# decoder with soft quantized input
z_dec_reverse = torch.flip(z_n[..., chunk['z_start'] : chunk['z_stop'] : chunk['z_stride'], :], [1])
dec_initial_state = states_n[..., chunk['z_stop'] - 1 : chunk['z_stop'], :]
features_reverse = self.core_decoder(z_dec_reverse, dec_initial_state)
outputs_sq.append((torch.flip(features_reverse, [1]), chunk['features_start'], chunk['features_stop']))
return {
'outputs_hard_quant' : outputs_hq,
'outputs_soft_quant' : outputs_sq,
'z' : z,
'states' : states,
'statistical_model' : statistical_model
}
def encode(self, features):
""" encoder with quantization and rate estimation """
z, states = self.core_encoder(features)
# quantization of initial states
states = soft_pvq(states, self.pvq_num_pulses)
state_size = m.log2(pvq_codebook_size(self.state_dim, self.pvq_num_pulses))
return z, states, state_size
def decode(self, z, initial_state):
""" decoder (flips sequences by itself) """
z_reverse = torch.flip(z, [1])
features_reverse = self.core_decoder(z_reverse, initial_state)
features = torch.flip(features_reverse, [1])
return features
def quantize(self, z, q_ids):
""" quantization of latent vectors """
stats = self.statistical_model(q_ids)
zq = z * stats['quant_scale'][..., :self.latent_dim]
zq = soft_dead_zone(zq, stats['dead_zone'][..., :self.latent_dim])
zq = torch.round(zq)
sizes = hard_rate_estimate(zq, stats['r_hard'][..., :self.latent_dim], stats['theta_hard'][..., :self.latent_dim], reduce=False)
return zq, sizes
def unquantize(self, zq, q_ids):
""" re-scaling of latent vector """
stats = self.statistical_model(q_ids)
z = zq / stats['quant_scale'][..., :self.latent_dim]
return z
def freeze_model(self):
# freeze all parameters
for p in self.parameters():
p.requires_grad = False
for p in self.statistical_model.parameters():
p.requires_grad = True