shithub: opus

Download patch

ref: 735117b6d7697312962e99597eb516fbb3cc8360
parent: ffd1b0b137e82197f8f995e03763a65b36079eba
author: Jan Buethe <jbuethe@amazon.de>
date: Thu Feb 15 10:25:06 EST 2024

Disabled the sparse option in the OSCE export script

--- a/dnn/torch/osce/export_model_weights.py
+++ b/dnn/torch/osce/export_model_weights.py
@@ -54,14 +54,14 @@
 parser.add_argument('output_dir', type=str, help='output folder')
 parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')
 
-
+sparse_default=False
 schedules = {
     'nolace': [
         ('pitch_embedding', dict()),
         ('feature_net.conv1', dict()),
-        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
-        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
-        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
+        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
+        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
+        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
         ('cf1', dict(quantize=True, scale=None)),
         ('cf2', dict(quantize=True, scale=None)),
         ('af1', dict(quantize=True, scale=None)),
@@ -71,18 +71,18 @@
         ('af2', dict(quantize=True, scale=None)),
         ('af3', dict(quantize=True, scale=None)),
         ('af4', dict(quantize=True, scale=None)),
-        ('post_cf1', dict(quantize=True, scale=None, sparse=True)),
-        ('post_cf2', dict(quantize=True, scale=None, sparse=True)),
-        ('post_af1', dict(quantize=True, scale=None, sparse=True)),
-        ('post_af2', dict(quantize=True, scale=None, sparse=True)),
-        ('post_af3', dict(quantize=True, scale=None, sparse=True))
+        ('post_cf1', dict(quantize=True, scale=None, sparse=sparse_default)),
+        ('post_cf2', dict(quantize=True, scale=None, sparse=sparse_default)),
+        ('post_af1', dict(quantize=True, scale=None, sparse=sparse_default)),
+        ('post_af2', dict(quantize=True, scale=None, sparse=sparse_default)),
+        ('post_af3', dict(quantize=True, scale=None, sparse=sparse_default))
     ],
     'lace' : [
         ('pitch_embedding', dict()),
         ('feature_net.conv1', dict()),
-        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
-        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
-        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
+        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
+        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
+        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
         ('cf1', dict(quantize=True, scale=None)),
         ('cf2', dict(quantize=True, scale=None)),
         ('af1', dict(quantize=True, scale=None))
--