ref: ec04a94eb2000e86b90c186ff56d518c260f2d60
parent: 5f8201c71ee81a9eb7bb302d1bd72fe15678d213
author: Jan Buethe <jbuethe@amazon.de>
date: Mon Jan 22 10:12:52 EST 2024
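bugfix in SilkFeatureNetPL: drop the repeat_upsamp special cases for the GRU input size and for tconv soft-quantization/sparsification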
--- a/dnn/torch/osce/models/silk_feature_net_pl.py
+++ b/dnn/torch/osce/models/silk_feature_net_pl.py
@@ -66,18 +66,17 @@
self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
- gru_input_dim = num_channels + self.repeat_upsamp_dim if self.repeat_upsamp else num_channels
- self.gru = norm(norm(nn.GRU(gru_input_dim, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
+ self.gru = norm(norm(nn.GRU(num_channels, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
if softquant:
self.conv2 = soft_quant(self.conv2)
- if not self.repeat_upsamp: self.tconv = soft_quant(self.tconv)
+ self.tconv = soft_quant(self.tconv)
self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
if sparsify:
mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
- if not self.repeat_upsamp: mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
+ mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
mark_for_sparsification(
self.gru,
{
--
⑨