ref: cb7cf92a52b69c4728bb9fa819b66f60019c5fc7
parent: a41a344a2e30455ce4d1f6662b85332a70dc4b52
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Mon Apr 21 07:23:29 EDT 2025
DRED: Add lambda schedule for first epochs
--- a/dnn/torch/rdovae/train_rdovae.py
+++ b/dnn/torch/rdovae/train_rdovae.py
@@ -163,6 +163,7 @@
# training loop
+ batch = 1  # training-step counter shared across epochs (not reset per epoch); drives the rate-lambda warm-up
for epoch in range(1, epochs + 1):
print(f"training epoch {epoch}...")
@@ -203,13 +204,20 @@
outputs_soft_quant = model_output['outputs_soft_quant']
statistical_model = model_output['statistical_model']
+ if args.initial_checkpoint is None:  # warm-up only when no initial checkpoint is given
+ latent_lambda = (1. - .5/(1.+batch/1000))  # ~0.5 at start, 0.75 at batch 1000, approaches 1
+ state_lambda = (1. - .9/(1.+batch/6000))  # ~0.1 at start, 0.55 at batch 6000, approaches 1
+ else:
+ latent_lambda = 1.  # resuming from a checkpoint: full rate weight immediately
+ state_lambda = 1.
+
# rate loss
hard_rate = hard_rate_estimate(z, statistical_model['r_hard'][:,:,:latent_dim], statistical_model['theta_hard'][:,:,:latent_dim], reduce=False)
soft_rate = soft_rate_estimate(z, statistical_model['r_soft'][:,:,:latent_dim], reduce=False)
states_hard_rate = hard_rate_estimate(states, statistical_model['r_hard'][:,:,latent_dim:], statistical_model['theta_hard'][:,:,latent_dim:], reduce=False)
states_soft_rate = soft_rate_estimate(states, statistical_model['r_soft'][:,:,latent_dim:], reduce=False)
- soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (soft_rate + .02*states_soft_rate))
- hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (hard_rate + .02*states_hard_rate))
+ soft_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*soft_rate + .04*state_lambda*states_soft_rate))
+ hard_rate_loss = torch.mean(torch.sqrt(rate_lambda) * (latent_lambda*hard_rate + .04*state_lambda*states_hard_rate))
rate_loss = (soft_rate_loss + 0.1 * hard_rate_loss)
hard_rate_metric = torch.mean(hard_rate)
states_rate_metric = torch.mean(states_hard_rate)
@@ -272,6 +280,7 @@
rateloss_soft=running_soft_rate_loss / (i + 1)
)
previous_total_loss = running_total_loss
+ batch += 1  # advance the step counter used by the latent/state rate-lambda schedules
# save checkpoint
checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
--
⑨