ref: 735117b6d7697312962e99597eb516fbb3cc8360
parent: ffd1b0b137e82197f8f995e03763a65b36079eba
author: Jan Buethe <jbuethe@amazon.de>
date: Thu Feb 15 10:25:06 EST 2024
Disable sparse option in OSCE export script (schedule sparse flags now read sparse_default, which is set to False)
--- a/dnn/torch/osce/export_model_weights.py
+++ b/dnn/torch/osce/export_model_weights.py
@@ -54,14 +54,14 @@
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')
-
+sparse_default=False
schedules = {
'nolace': [
('pitch_embedding', dict()),
('feature_net.conv1', dict()),
- ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
- ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
- ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
+ ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
+ ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
+ ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
('cf1', dict(quantize=True, scale=None)),
('cf2', dict(quantize=True, scale=None)),
('af1', dict(quantize=True, scale=None)),
@@ -71,18 +71,18 @@
('af2', dict(quantize=True, scale=None)),
('af3', dict(quantize=True, scale=None)),
('af4', dict(quantize=True, scale=None)),
- ('post_cf1', dict(quantize=True, scale=None, sparse=True)),
- ('post_cf2', dict(quantize=True, scale=None, sparse=True)),
- ('post_af1', dict(quantize=True, scale=None, sparse=True)),
- ('post_af2', dict(quantize=True, scale=None, sparse=True)),
- ('post_af3', dict(quantize=True, scale=None, sparse=True))
+ ('post_cf1', dict(quantize=True, scale=None, sparse=sparse_default)),
+ ('post_cf2', dict(quantize=True, scale=None, sparse=sparse_default)),
+ ('post_af1', dict(quantize=True, scale=None, sparse=sparse_default)),
+ ('post_af2', dict(quantize=True, scale=None, sparse=sparse_default)),
+ ('post_af3', dict(quantize=True, scale=None, sparse=sparse_default))
],
'lace' : [
('pitch_embedding', dict()),
('feature_net.conv1', dict()),
- ('feature_net.conv2', dict(quantize=True, scale=None, sparse=True)),
- ('feature_net.tconv', dict(quantize=True, scale=None, sparse=True)),
- ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=True, recurrent_sparse=True)),
+ ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
+ ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
+ ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
('cf1', dict(quantize=True, scale=None)),
('cf2', dict(quantize=True, scale=None)),
('af1', dict(quantize=True, scale=None))
--
⑨