shithub: opus

Download patch

ref: cb7cf92a52b69c4728bb9fa819b66f60019c5fc7
parent: a41a344a2e30455ce4d1f6662b85332a70dc4b52
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Mon Apr 21 07:23:29 EDT 2025

DRED: Add lambda schedule for first epochs

--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -163,6 +163,7 @@
 
     # training loop
 
+    batch = 1
     for epoch in range(1, epochs + 1):
 
         print(f"training epoch {epoch}...")
@@ -203,13 +204,20 @@
                 outputs_soft_quant  = model_output['outputs_soft_quant']
                 statistical_model   = model_output['statistical_model']
 
+                if type(args.initial_checkpoint) == type(None):
+                    latent_lambda = (1. - .5/(1.+batch/1000))
+                    state_lambda = (1. - .9/(1.+batch/6000))
+                else:
+                    latent_lambda = 1.
+                    state_lambda = 1.
+
                 # rate loss
                 hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
                 soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
                 states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
                 states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
-                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate))
-                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate))
+                soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*soft_rate + .04*state_lambda*states_soft_rate))
+                hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*hard_rate + .04*state_lambda*states_hard_rate))
                 rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
                 hard_rate_metric = torch.mean(hard_rate)
                 states_rate_metric = torch.mean(states_hard_rate)
@@ -272,6 +280,7 @@
                         rateloss_soft=running_soft_rate_loss / (i + 1)
                     )
                     previous_total_loss = running_total_loss
+                batch = batch+1
 
         # save checkpoint
         checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
--