shithub: opus

Download patch

ref: b0620c0bf9864d9b18ead6b4bb6e0800542a931d
parent: 58923f61c26ac0f5d8284d427344466e3bc2c674
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Tue Nov 14 23:08:50 EST 2023

Use sparse GRUs in the DRED decoder

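Sparsifying the input weight matrices (W_ir, W_iz, W_in) of the five decoder
GRUs saves ~270 kB of weights in the decoder. Each decoder GRU output is now
also passed through a GLU gate. Weight norm is removed before export, and
DRED_EXPERIMENTAL_VERSION is bumped to 8.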

--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@
 srcdir=`dirname $0`
 test -n "$srcdir" && cd "$srcdir"
 
-dnn/download_model.sh b6095cf
+dnn/download_model.sh 58923f6
 
 echo "Updating build configuration files, please wait...."
 
--- a/dnn/dred_rdovae_dec.c
+++ b/dnn/dred_rdovae_dec.c
@@ -98,7 +98,7 @@
     output_index += DEC_DENSE1_OUT_SIZE;
 
     compute_generic_gru(&model->dec_gru1_input, &model->dec_gru1_recurrent, dec_state->gru1_state, buffer);
-    OPUS_COPY(&buffer[output_index], dec_state->gru1_state, DEC_GRU1_OUT_SIZE);
+    compute_glu(&model->dec_glu1, &buffer[output_index], dec_state->gru1_state);
     output_index += DEC_GRU1_OUT_SIZE;
     conv1_cond_init(dec_state->conv1_state, output_index, 1, &dec_state->initialized);
     compute_generic_conv1d(&model->dec_conv1, &buffer[output_index], dec_state->conv1_state, buffer, output_index, ACTIVATION_TANH);
@@ -105,7 +105,7 @@
     output_index += DEC_CONV1_OUT_SIZE;
 
     compute_generic_gru(&model->dec_gru2_input, &model->dec_gru2_recurrent, dec_state->gru2_state, buffer);
-    OPUS_COPY(&buffer[output_index], dec_state->gru2_state, DEC_GRU2_OUT_SIZE);
+    compute_glu(&model->dec_glu2, &buffer[output_index], dec_state->gru2_state);
     output_index += DEC_GRU2_OUT_SIZE;
     conv1_cond_init(dec_state->conv2_state, output_index, 1, &dec_state->initialized);
     compute_generic_conv1d(&model->dec_conv2, &buffer[output_index], dec_state->conv2_state, buffer, output_index, ACTIVATION_TANH);
@@ -112,7 +112,7 @@
     output_index += DEC_CONV2_OUT_SIZE;
 
     compute_generic_gru(&model->dec_gru3_input, &model->dec_gru3_recurrent, dec_state->gru3_state, buffer);
-    OPUS_COPY(&buffer[output_index], dec_state->gru3_state, DEC_GRU3_OUT_SIZE);
+    compute_glu(&model->dec_glu3, &buffer[output_index], dec_state->gru3_state);
     output_index += DEC_GRU3_OUT_SIZE;
     conv1_cond_init(dec_state->conv3_state, output_index, 1, &dec_state->initialized);
     compute_generic_conv1d(&model->dec_conv3, &buffer[output_index], dec_state->conv3_state, buffer, output_index, ACTIVATION_TANH);
@@ -119,7 +119,7 @@
     output_index += DEC_CONV3_OUT_SIZE;
 
     compute_generic_gru(&model->dec_gru4_input, &model->dec_gru4_recurrent, dec_state->gru4_state, buffer);
-    OPUS_COPY(&buffer[output_index], dec_state->gru4_state, DEC_GRU4_OUT_SIZE);
+    compute_glu(&model->dec_glu4, &buffer[output_index], dec_state->gru4_state);
     output_index += DEC_GRU4_OUT_SIZE;
     conv1_cond_init(dec_state->conv4_state, output_index, 1, &dec_state->initialized);
     compute_generic_conv1d(&model->dec_conv4, &buffer[output_index], dec_state->conv4_state, buffer, output_index, ACTIVATION_TANH);
@@ -126,7 +126,7 @@
     output_index += DEC_CONV4_OUT_SIZE;
 
     compute_generic_gru(&model->dec_gru5_input, &model->dec_gru5_recurrent, dec_state->gru5_state, buffer);
-    OPUS_COPY(&buffer[output_index], dec_state->gru5_state, DEC_GRU5_OUT_SIZE);
+    compute_glu(&model->dec_glu5, &buffer[output_index], dec_state->gru5_state);
     output_index += DEC_GRU5_OUT_SIZE;
     conv1_cond_init(dec_state->conv5_state, output_index, 1, &dec_state->initialized);
     compute_generic_conv1d(&model->dec_conv5, &buffer[output_index], dec_state->conv5_state, buffer, output_index, ACTIVATION_TANH);
--- a/dnn/torch/lpcnet/utils/sparsification/common.py
+++ b/dnn/torch/lpcnet/utils/sparsification/common.py
@@ -29,7 +29,7 @@
 
 import torch
 
-def sparsify_matrix(matrix : torch.tensor, density : float, block_size : list[int, int], keep_diagonal : bool=False, return_mask : bool=False):
+def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
     """ sparsifies matrix with specified block size
 
         Parameters:
@@ -118,4 +118,4 @@
     # activations estimated by 10 flops per activation
     flops += 30 * hidden_size
 
-    return flops
\ No newline at end of file
+    return flops
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -225,10 +225,15 @@
 
     # decoder
     decoder_dense_layers = [
-        ('core_decoder.module.dense_1'       , 'dec_dense1',   'TANH', False),
-        ('core_decoder.module.output'       , 'dec_output',   'LINEAR', True),
+        ('core_decoder.module.dense_1'      , 'dec_dense1',  'TANH', False),
+        ('core_decoder.module.glu1.gate'    , 'dec_glu1',    'TANH', True),
+        ('core_decoder.module.glu2.gate'    , 'dec_glu2',    'TANH', True),
+        ('core_decoder.module.glu3.gate'    , 'dec_glu3',    'TANH', True),
+        ('core_decoder.module.glu4.gate'    , 'dec_glu4',    'TANH', True),
+        ('core_decoder.module.glu5.gate'    , 'dec_glu5',    'TANH', True),
+        ('core_decoder.module.output'       , 'dec_output',  'LINEAR', True),
         ('core_decoder.module.hidden_init'  , 'dec_hidden_init',        'TANH', False),
-        ('core_decoder.module.gru_init'    , 'dec_gru_init',        'TANH', True),
+        ('core_decoder.module.gru_init'     , 'dec_gru_init','TANH', True),
     ]
 
     for name, export_name, _, quantize in decoder_dense_layers:
@@ -338,6 +343,13 @@
     checkpoint = torch.load(args.checkpoint, map_location='cpu')
     model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
     missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
+    def _remove_weight_norm(m):
+        try:
+            torch.nn.utils.remove_weight_norm(m)
+        except ValueError:  # this module didn't have weight norm
+            return
+    model.apply(_remove_weight_norm)
+
 
     if len(missing_keys) > 0:
         raise ValueError(f"error: missing keys in state dict")
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -34,6 +34,12 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+import sys
+import os
+source_dir = os.path.split(os.path.abspath(__file__))[0]
+sys.path.append(os.path.join(source_dir, "../../lpcnet/"))
+from utils.sparsification import GRUSparsifier
+from torch.nn.utils import weight_norm
 
 # Quantization and rate related utily functions
 
@@ -227,6 +233,32 @@
 
 # RDOVAE module and submodules
 
+sparsify_start     = 12000
+sparsify_stop      = 24000
+sparsify_interval  = 100
+sparsify_exponent  = 3
+#sparsify_start     = 0
+#sparsify_stop      = 0
+
+sparse_params1 = {
+#                'W_hr' : (1.0, [8, 4], True),
+#                'W_hz' : (1.0, [8, 4], True),
+#                'W_hn' : (1.0, [8, 4], True),
+                'W_ir' : (0.6, [8, 4], False),
+                'W_iz' : (0.4, [8, 4], False),
+                'W_in' : (0.8, [8, 4], False)
+                }
+
+sparse_params2 = {
+#                'W_hr' : (1.0, [8, 4], True),
+#                'W_hz' : (1.0, [8, 4], True),
+#                'W_hn' : (1.0, [8, 4], True),
+                'W_ir' : (0.3, [8, 4], False),
+                'W_iz' : (0.2, [8, 4], False),
+                'W_in' : (0.4, [8, 4], False)
+                }
+
+
 class MyConv(nn.Module):
     def __init__(self, input_dim, output_dim, dilation=1):
         super(MyConv, self).__init__()
@@ -239,6 +271,29 @@
         conv_in = torch.cat([torch.zeros_like(x[:,0:self.dilation,:], device=device), x], -2).permute(0, 2, 1)
         return torch.tanh(self.conv(conv_in)).permute(0, 2, 1)
 
+class GLU(nn.Module):
+    def __init__(self, feat_size):
+        super(GLU, self).__init__()
+
+        torch.manual_seed(5)
+
+        self.gate = weight_norm(nn.Linear(feat_size, feat_size, bias=False))
+
+        self.init_weights()
+
+    def init_weights(self):
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d)\
+            or isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+                nn.init.orthogonal_(m.weight.data)
+
+    def forward(self, x):
+
+        out = x * torch.sigmoid(self.gate(x))
+
+        return out
+
 class CoreEncoder(nn.Module):
     STATE_HIDDEN = 128
     FRAMES_PER_STEP = 2
@@ -355,7 +410,11 @@
         self.gru5 = nn.GRU(608, 96, batch_first=True)
         self.conv5 = MyConv(704, 32)
         self.output  = nn.Linear(736, self.FRAMES_PER_STEP * self.output_dim)
-
+        self.glu1 = GLU(96)
+        self.glu2 = GLU(96)
+        self.glu3 = GLU(96)
+        self.glu4 = GLU(96)
+        self.glu5 = GLU(96)
         self.hidden_init = nn.Linear(self.state_size, 128)
         self.gru_init = nn.Linear(128, 480)
 
@@ -363,7 +422,17 @@
         print(f"decoder: {nb_params} weights")
         # initialize weights
         self.apply(init_weights)
+        self.sparsifier = []
+        self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
+        self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
+        self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
+        self.sparsifier.append(GRUSparsifier([(self.gru4, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
+        self.sparsifier.append(GRUSparsifier([(self.gru5, sparse_params2)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent))
 
+    def sparsify(self):
+        for sparsifier in self.sparsifier:
+            sparsifier.step()
+
     def forward(self, z, initial_state):
 
         hidden = torch.tanh(self.hidden_init(initial_state))
@@ -377,15 +446,15 @@
         # run decoding layer stack
         x = n(torch.tanh(self.dense_1(z)))
 
-        x = torch.cat([x, n(self.gru1(x, h1_state)[0])], -1)
+        x = torch.cat([x, n(self.glu1(n(self.gru1(x, h1_state)[0])))], -1)
         x = torch.cat([x, n(self.conv1(x))], -1)
-        x = torch.cat([x, n(self.gru2(x, h2_state)[0])], -1)
+        x = torch.cat([x, n(self.glu2(n(self.gru2(x, h2_state)[0])))], -1)
         x = torch.cat([x, n(self.conv2(x))], -1)
-        x = torch.cat([x, n(self.gru3(x, h3_state)[0])], -1)
+        x = torch.cat([x, n(self.glu3(n(self.gru3(x, h3_state)[0])))], -1)
         x = torch.cat([x, n(self.conv3(x))], -1)
-        x = torch.cat([x, n(self.gru4(x, h4_state)[0])], -1)
+        x = torch.cat([x, n(self.glu4(n(self.gru4(x, h4_state)[0])))], -1)
         x = torch.cat([x, n(self.conv4(x))], -1)
-        x = torch.cat([x, n(self.gru5(x, h5_state)[0])], -1)
+        x = torch.cat([x, n(self.glu5(n(self.gru5(x, h5_state)[0])))], -1)
         x = torch.cat([x, n(self.conv5(x))], -1)
 
         # output layer and reshaping
@@ -489,6 +558,10 @@
     def clip_weights(self):
         if not type(self.weight_clip_fn) == type(None):
             self.apply(self.weight_clip_fn)
+
+    def sparsify(self):
+        #self.core_encoder.module.sparsify()
+        self.core_decoder.module.sparsify()
 
     def get_decoder_chunks(self, z_frames, mode='split', chunks_per_offset = 4):
 
--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -84,7 +84,7 @@
 lr_decay_factor = args.lr_decay_factor
 split_mode = args.split_mode
 # not exposed
-adam_betas = [0.9, 0.99]
+adam_betas = [0.8, 0.95]
 adam_eps = 1e-8
 
 checkpoint['batch_size'] = batch_size
@@ -239,6 +239,7 @@
                 optimizer.step()
 
                 model.clip_weights()
+                model.sparsify()
 
                 scheduler.step()
 
--- a/silk/dred_config.h
+++ b/silk/dred_config.h
@@ -32,7 +32,7 @@
 #define DRED_EXTENSION_ID 126
 
 /* Remove these two completely once DRED gets an extension number assigned. */
-#define DRED_EXPERIMENTAL_VERSION 7
+#define DRED_EXPERIMENTAL_VERSION 8
 #define DRED_EXPERIMENTAL_BYTES 2
 
 
--