shithub: opus

Download patch

ref: 1ca6933ac480c8584e5e47b4fea5ad2d857a336d
parent: ebccedd918f5803dd1bf3459ad3afd23c4956ea0
author: Jan Buethe <jan.buethe@gmx.net>
date: Sat Mar 8 08:03:02 EST 2025

added soft quantization to RDOVAE and FARGAN

--- a/dnn/torch/fargan/adv_train_fargan.py
+++ b/dnn/torch/fargan/adv_train_fargan.py
@@ -44,6 +44,7 @@
 model_group = parser.add_argument_group(title="model parameters")
 model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
 model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
+model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
 
 training_group = parser.add_argument_group(title="training parameters")
 training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128)
@@ -93,7 +94,7 @@
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 checkpoint['model_args']    = ()
-checkpoint['model_kwargs']  = {'cond_size': cond_size, 'gamma': args.gamma}
+checkpoint['model_kwargs']  = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
 print(checkpoint['model_kwargs'])
 model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
 
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -1,3 +1,6 @@
+import os
+import sys
+
 import numpy as np
 import torch
 from torch import nn
@@ -7,6 +10,11 @@
 #from convert_lsp import lpc_to_lsp, lsp_to_lpc
 from rc import lpc2rc, rc2lpc
 
+source_dir = os.path.split(os.path.abspath(__file__))[0]
+sys.path.append(os.path.join(source_dir, "../dnntools"))
+from dnntools.quantization import soft_quant
+
+
 Fs = 16000
 
 fid_dict = {}
@@ -102,7 +110,7 @@
     return torch.cos(embed), torch.sin(embed)
 
 class GLU(nn.Module):
-    def __init__(self, feat_size):
+    def __init__(self, feat_size, softquant=False):
         super(GLU, self).__init__()
 
         torch.manual_seed(5)
@@ -109,6 +117,9 @@
 
         self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
 
+        if softquant:
+            self.gate = soft_quant(self.gate)
+
         self.init_weights()
 
     def init_weights(self):
@@ -125,7 +136,7 @@
         return out
 
 class FWConv(nn.Module):
-    def __init__(self, in_size, out_size, kernel_size=2):
+    def __init__(self, in_size, out_size, kernel_size=2, softquant=False):
         super(FWConv, self).__init__()
 
         torch.manual_seed(5)
@@ -133,8 +144,11 @@
         self.in_size = in_size
         self.kernel_size = kernel_size
         self.conv = weight_norm(nn.Linear(in_size*self.kernel_size, out_size, bias=False))
-        self.glu = GLU(out_size)
+        self.glu = GLU(out_size, softquant=softquant)
 
+        if softquant:
+            self.conv = soft_quant(self.conv)
+
         self.init_weights()
 
     def init_weights(self):
@@ -154,7 +168,7 @@
     return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
 
 class FARGANCond(nn.Module):
-    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
+    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12, softquant=False):
         super(FARGANCond, self).__init__()
 
         self.feature_dim = feature_dim
@@ -165,6 +179,10 @@
         self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
         self.fdense2 = nn.Linear(128, 80*4, bias=False)
 
+        if softquant:
+            self.fconv1 = soft_quant(self.fconv1)
+            self.fdense2 = soft_quant(self.fdense2)
+
         self.apply(init_weights)
         nb_params = sum(p.numel() for p in self.parameters())
         print(f"cond model: {nb_params} weights")
@@ -183,7 +201,7 @@
         return tmp
 
 class FARGANSub(nn.Module):
-    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256):
+    def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256, softquant=False):
         super(FARGANSub, self).__init__()
 
         self.subframe_size = subframe_size
@@ -192,15 +210,15 @@
         self.cond_gain_dense = nn.Linear(80, 1)
 
         #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
-        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192)
+        self.fwc0 = FWConv(2*self.subframe_size+80+4, 192, softquant=softquant)
         self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
         self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
         self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)
 
-        self.gru1_glu = GLU(160)
-        self.gru2_glu = GLU(128)
-        self.gru3_glu = GLU(128)
-        self.skip_glu = GLU(128)
+        self.gru1_glu = GLU(160, softquant=softquant)
+        self.gru2_glu = GLU(128, softquant=softquant)
+        self.gru3_glu = GLU(128, softquant=softquant)
+        self.skip_glu = GLU(128, softquant=softquant)
         #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
 
         self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
@@ -207,6 +225,12 @@
         self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
         self.gain_dense_out = nn.Linear(192, 4)
 
+        if softquant:
+            self.gru1 = soft_quant(self.gru1, names=['weight_hh', 'weight_ih'])
+            self.gru2 = soft_quant(self.gru2, names=['weight_hh', 'weight_ih'])
+            self.gru3 = soft_quant(self.gru3, names=['weight_hh', 'weight_ih'])
+            self.skip_dense = soft_quant(self.skip_dense)
+            self.sig_dense_out = soft_quant(self.sig_dense_out)
 
         self.apply(init_weights)
         nb_params = sum(p.numel() for p in self.parameters())
@@ -271,7 +295,7 @@
         return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
 
 class FARGAN(nn.Module):
-    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
+    def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None, softquant=False):
         super(FARGAN, self).__init__()
 
         self.subframe_size = subframe_size
@@ -280,8 +304,8 @@
         self.feature_dim = feature_dim
         self.cond_size = cond_size
 
-        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size)
-        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size)
+        self.cond_net = FARGANCond(feature_dim=feature_dim, cond_size=cond_size, softquant=softquant)
+        self.sig_net = FARGANSub(subframe_size=subframe_size, nb_subframes=nb_subframes, cond_size=cond_size, softquant=softquant)
 
     def forward(self, features, period, nb_frames, pre=None, states=None):
         device = features.device
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -25,6 +25,7 @@
 model_group = parser.add_argument_group(title="model parameters")
 model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
 model_group.add_argument('--gamma', type=float, help="Use A(z/gamma), default: 0.9", default=0.9)
+model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
 
 training_group = parser.add_argument_group(title="training parameters")
 training_group.add_argument('--batch-size', type=int, help="batch size, default: 512", default=512)
@@ -72,7 +73,7 @@
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
 checkpoint['model_args']    = ()
-checkpoint['model_kwargs']  = {'cond_size': cond_size, 'gamma': args.gamma}
+checkpoint['model_kwargs']  = {'cond_size': cond_size, 'gamma': args.gamma, 'softquant': args.softquant}
 print(checkpoint['model_kwargs'])
 model = fargan.FARGAN(*checkpoint['model_args'], **checkpoint['model_kwargs'])
 
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -40,6 +40,8 @@
 sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
 from utils.sparsification import GRUSparsifier
 from torch.nn.utils import weight_norm
+sys.path.append(os.path.join(source_dir, "../../dnntools"))
+from dnntools.quantization import soft_quant
 
 # Quantization and rate related utily functions
 
@@ -260,12 +262,16 @@
 
 
 class MyConv(nn.Module):
-    def __init__(self, input_dim, output_dim, dilation=1):
+    def __init__(self, input_dim, output_dim, dilation=1, softquant=False):
         super(MyConv, self).__init__()
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.dilation=dilation
         self.conv = nn.Conv1d(input_dim, output_dim, kernel_size=2, padding='valid', dilation=dilation)
+
+        if softquant:
+            self.conv = soft_quant(self.conv)
+
     def forward(self, x, state=None):
         device = x.device
         conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
@@ -272,7 +278,7 @@
         return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
 
 class GLU(nn.Module):
-    def __init__(self, feat_size):
+    def __init__(self, feat_size, softquant=False):
         super(GLU, self).__init__()
 
         torch.manual_seed(5)
@@ -279,6 +285,9 @@
 
         self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
 
+        if softquant:
+            self.gate = soft_quant(self.gate)
+
         self.init_weights()
 
     def init_weights(self):
@@ -299,7 +308,7 @@
     FRAMES_PER_STEP = 2
     CONV_KERNEL_SIZE = 4
 
-    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24):
+    def __init__(self, feature_dim, output_dim, cond_size, cond_size2, state_size=24, softquant=False):
         """ core encoder for RDOVAE
 
             Computes latents, initial states, and rate estimates from features and lambda parameter
@@ -321,15 +330,15 @@
         # layers
         self.dense_1 = nn.Linear(self.input_dim, 64)
         self.gru1 = nn.GRU(64, 64, batch_first=True)
-        self.conv1 = MyConv(128, 96)
+        self.conv1 = MyConv(128, 96, softquant=softquant)  # honor the constructor flag, as CoreDecoder does
         self.gru2 = nn.GRU(224, 64, batch_first=True)
-        self.conv2 = MyConv(288, 96, dilation=2)
+        self.conv2 = MyConv(288, 96, dilation=2, softquant=softquant)
         self.gru3 = nn.GRU(384, 64, batch_first=True)
-        self.conv3 = MyConv(448, 96, dilation=2)
+        self.conv3 = MyConv(448, 96, dilation=2, softquant=softquant)
         self.gru4 = nn.GRU(544, 64, batch_first=True)
-        self.conv4 = MyConv(608, 96, dilation=2)
+        self.conv4 = MyConv(608, 96, dilation=2, softquant=softquant)
         self.gru5 = nn.GRU(704, 64, batch_first=True)
-        self.conv5 = MyConv(768, 96, dilation=2)
+        self.conv5 = MyConv(768, 96, dilation=2, softquant=softquant)
 
         self.z_dense = nn.Linear(864, self.output_dim)
 
@@ -343,7 +352,17 @@
         # initialize weights
         self.apply(init_weights)
 
+        if softquant:
+            self.gru1 = soft_quant(self.gru1, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru2 = soft_quant(self.gru2, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru3 = soft_quant(self.gru3, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru4 = soft_quant(self.gru4, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru5 = soft_quant(self.gru5, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.z_dense = soft_quant(self.z_dense)
+            self.state_dense_1 = soft_quant(self.state_dense_1)
+            self.state_dense_2 = soft_quant(self.state_dense_2)
 
+
     def forward(self, features):
 
         # reshape features
@@ -379,7 +398,7 @@
 
     FRAMES_PER_STEP = 4
 
-    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24):
+    def __init__(self, input_dim, output_dim, cond_size, cond_size2, state_size=24, softquant=False):
         """ core decoder for RDOVAE
 
             Computes features from latents, initial state, and quantization index
@@ -400,21 +419,21 @@
         # layers
         self.dense_1    = nn.Linear(self.input_size, 96)
         self.gru1 = nn.GRU(96, 96, batch_first=True)
-        self.conv1 = MyConv(192, 32)
+        self.conv1 = MyConv(192, 32, softquant=softquant)
         self.gru2 = nn.GRU(224, 96, batch_first=True)
-        self.conv2 = MyConv(320, 32)
+        self.conv2 = MyConv(320, 32, softquant=softquant)
         self.gru3 = nn.GRU(352, 96, batch_first=True)
-        self.conv3 = MyConv(448, 32)
+        self.conv3 = MyConv(448, 32, softquant=softquant)
         self.gru4 = nn.GRU(480, 96, batch_first=True)
-        self.conv4 = MyConv(576, 32)
+        self.conv4 = MyConv(576, 32, softquant=softquant)
         self.gru5 = nn.GRU(608, 96, batch_first=True)
-        self.conv5 = MyConv(704, 32)
+        self.conv5 = MyConv(704, 32, softquant=softquant)
         self.output  = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
-        self.glu1 = GLU(96)
-        self.glu2 = GLU(96)
-        self.glu3 = GLU(96)
-        self.glu4 = GLU(96)
-        self.glu5 = GLU(96)
+        self.glu1 = GLU(96, softquant=softquant)
+        self.glu2 = GLU(96, softquant=softquant)
+        self.glu3 = GLU(96, softquant=softquant)
+        self.glu4 = GLU(96, softquant=softquant)
+        self.glu5 = GLU(96, softquant=softquant)
         self.hidden_init = nn.Linear(self.state_size, 128)
         self.gru_init = nn.Linear(128, 480)
 
@@ -429,6 +448,15 @@
         self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
         self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
 
+        if softquant:
+            self.gru1 = soft_quant(self.gru1, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru2 = soft_quant(self.gru2, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru3 = soft_quant(self.gru3, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru4 = soft_quant(self.gru4, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.gru5 = soft_quant(self.gru5, names=['weight_hh_l0', 'weight_ih_l0'])
+            self.output = soft_quant(self.output)
+            self.gru_init = soft_quant(self.gru_init)
+
     def sparsify(self):
         for sparsifier in self.sparsifier:
             sparsifier.step()
@@ -525,7 +553,8 @@
                  split_mode='split',
                  clip_weights=False,
                  pvq_num_pulses=82,
-                 state_dropout_rate=0):
+                 state_dropout_rate=0,
+                 softquant=False):
 
         super(RDOVAE, self).__init__()
 
@@ -541,8 +570,8 @@
 
         # submodules encoder and decoder share the statistical model
         self.statistical_model = StatisticalModel(quant_levels, latent_dim, state_dim)
-        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim))
-        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim))
+        self.core_encoder = nn.DataParallel(CoreEncoder(feature_dim, latent_dim, cond_size, cond_size2, state_size=state_dim, softquant=softquant))
+        self.core_decoder = nn.DataParallel(CoreDecoder(latent_dim, feature_dim, cond_size, cond_size2, state_size=state_dim, softquant=softquant))
 
         self.enc_stride = CoreEncoder.FRAMES_PER_STEP
         self.dec_stride = CoreDecoder.FRAMES_PER_STEP
--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -54,6 +54,7 @@
 model_group.add_argument('--lambda-max', type=float, help="maximal value for rate lambda, default: 0.0104", default=0.0104)
 model_group.add_argument('--pvq-num-pulses', type=int, help="number of pulses for PVQ, default: 82", default=82)
 model_group.add_argument('--state-dropout-rate', type=float, help="state dropout rate, default: 0", default=0.0)
+model_group.add_argument('--softquant', action="store_true", help="enables soft quantization during training")
 
 training_group = parser.add_argument_group(title="training parameters")
 training_group.add_argument('--batch-size', type=int, help="batch size, default: 32", default=32)
@@ -109,6 +110,7 @@
 lambda_min = args.lambda_min
 lambda_max = args.lambda_max
 state_dim = args.state_dim
+softquant = args.softquant
 # not expsed
 num_features = 20
 
@@ -118,7 +120,7 @@
 
 # model
 checkpoint['model_args']    = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
-checkpoint['model_kwargs']  = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate}
+checkpoint['model_kwargs']  = {'state_dim': state_dim, 'split_mode' : split_mode, 'pvq_num_pulses': args.pvq_num_pulses, 'state_dropout_rate': args.state_dropout_rate, 'softquant': softquant}
 model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
 
 if type(args.initial_checkpoint) != type(None):
--