ref: 6a45b767e29e0b12a94c85f0f851dd18f9ac9aef
parent: cb7cf92a52b69c4728bb9fa819b66f60019c5fc7
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Mon Apr 21 07:24:08 EDT 2025
Add skewed split for fine-tuning decoder
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -624,6 +624,8 @@
split_points = [start + stride * int(i * length / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
elif mode == 'random_split':
split_points = [stride * x + start for x in random_split(0, (stop - start)//stride - 1, chunks_per_offset - 1, 1)]
+ elif mode == 'skewed_split':
+ split_points = [start + stride * int(i * length / 4 / chunks_per_offset / stride) for i in range(chunks_per_offset)] + [stop]
else:
raise ValueError(f"get_decoder_chunks_generic: unknown mode {mode}")
--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -63,7 +63,7 @@
training_group.add_argument('--sequence-length', type=int, help='sequence length, needs to be divisible by chunks_per_offset, default: 400', default=400)
training_group.add_argument('--chunks-per-offset', type=int, help='chunks per offset', default=4)
training_group.add_argument('--lr-decay-factor', type=float, help='learning rate decay factor, default: 2.5e-5', default=2.5e-5)
-training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split'], help='splitting mode for decoder input, default: split', default='split')
+training_group.add_argument('--split-mode', type=str, choices=['split', 'random_split', 'skewed_split'], help='splitting mode for decoder input, default: split', default='split')
training_group.add_argument('--enable-first-frame-loss', action='store_true', default=False, help='enables dedicated distortion loss on first 4 decoder frames')
training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
training_group.add_argument('--train-decoder-only', action='store_true', help='freeze encoder and statistical model and train decoder only')
--
⑨