shithub: opus

Download patch

ref: 9e76a7bfb835ebe7cb97cf24da98462b78de0207
parent: d1c5b32add990473df84e42a8db64851b2dd65f6
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Mon Oct 9 20:51:57 EDT 2023

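update fargan to match version 45

Drop LPC analysis filtering of the training targets; the test script now
undoes a fixed 1/(1 - 0.85 z^-1) filter instead of inverse perceptual
weighting. Encode the pitch period as round(clip(256/2^(x+1.5), 32, 255))
and embed it as period-32. Shrink the conditioning network (no fdense2),
remove the phase embedding, predict the excitation gain from the
conditioning vector, and feed the pitch prediction, scaled by four
learned sigmoid gains, into each GRU layer and the new skip layer instead
of adding it to the output. Add LPC<->reflection-coefficient helpers
(rc.py) and an RC-domain interp_lpc().

Training: the adversarial trainer freezes the generator until batch 400
of epoch 1 and adds a decaying 15/(1 + batches/7600) term to reg_weight;
the spectral-convergence loss now uses an L1 norm of sqrt magnitudes;
the multi-resolution STFT loss defaults to FFT sizes 2560..80; the
non-adversarial trainer weights a 160-sample continuity loss by 0.03.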

--- a/dnn/torch/fargan/adv_train_fargan.py
+++ b/dnn/torch/fargan/adv_train_fargan.py
@@ -132,6 +132,10 @@
 
 spect_loss =  MultiResolutionSTFTLoss(device).to(device)
 
+for param in model.parameters():
+    param.requires_grad = False
+
+batch_count = 0
 if __name__ == '__main__':
     model.to(device)
     disc.to(device)
@@ -153,22 +157,28 @@
         print(f"training epoch {epoch}...")
         with tqdm.tqdm(dataloader, unit='batch') as tepoch:
             for i, (features, periods, target, lpc) in enumerate(tepoch):
+                if epoch == 1 and i == 400:
+                    for param in model.parameters():
+                        param.requires_grad = True
+
                 optimizer.zero_grad()
                 features = features.to(device)
-                lpc = lpc.to(device)
+                #lpc = lpc.to(device)
+                #lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
+                #lpc = fargan.interp_lpc(lpc, 4)
                 periods = periods.to(device)
                 if True:
                     target = target[:, :sequence_length*160]
-                    lpc = lpc[:,:sequence_length,:]
+                    #lpc = lpc[:,:sequence_length*4,:]
                     features = features[:,:sequence_length+4,:]
                     periods = periods[:,:sequence_length+4]
                 else:
                     target=target[::2, :]
-                    lpc=lpc[::2,:]
+                    #lpc=lpc[::2,:]
                     features=features[::2,:]
                     periods=periods[::2,:]
                 target = target.to(device)
-                target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
+                #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
 
                 #nb_pre = random.randrange(1, 6)
                 nb_pre = 2
@@ -208,7 +218,7 @@
 
                 cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80])
                 specc_loss = spect_loss(output, target.detach())
-                reg_loss = args.reg_weight * (.00*cont_loss + specc_loss)
+                reg_loss = (.00*cont_loss + specc_loss)
 
                 loss_gen = 0
                 for scale in scores_gen:
@@ -216,7 +226,8 @@
 
                 feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen)
 
-                gen_loss = reg_loss +  feat_loss + loss_gen
+                reg_weight = args.reg_weight + 15./(1 + (batch_count/7600.))
+                gen_loss = reg_weight * reg_loss +  feat_loss + loss_gen
 
                 model.zero_grad()
 
@@ -238,6 +249,7 @@
 
 
                 tepoch.set_postfix(cont_loss=f"{running_cont_loss/(i+1):8.5f}",
+                                   reg_weight=f"{reg_weight:8.5f}",
                                    gen_loss=f"{running_gen_loss/(i+1):8.5f}",
                                    disc_loss=f"{running_disc_loss/(i+1):8.5f}",
                                    fmap_loss=f"{running_fmap_loss/(i+1):8.5f}",
@@ -244,6 +256,7 @@
                                    reg_loss=f"{running_reg_loss/(i+1):8.5f}",
                                    wc = f"{running_wc/(i+1):8.5f}",
                                    )
+                batch_count = batch_count + 1
 
         # save checkpoint
         checkpoint_path = os.path.join(checkpoint_dir, f'fargan{args.suffix}_adv_{epoch}.pth')
--- a/dnn/torch/fargan/dataset.py
+++ b/dnn/torch/fargan/dataset.py
@@ -1,5 +1,6 @@
 import torch
 import numpy as np
+import fargan
 
 class FARGANDataset(torch.utils.data.Dataset):
     def __init__(self,
@@ -34,7 +35,8 @@
         sizeof = self.features.strides[-1]
         self.features = np.lib.stride_tricks.as_strided(self.features, shape=(self.nb_sequences, self.sequence_length*2+4, nb_features),
                                            strides=(self.sequence_length*self.nb_features*sizeof, self.nb_features*sizeof, sizeof))
-        self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
+        #self.periods = np.round(50*self.features[:,:,self.nb_used_features-2]+100).astype('int')
+        self.periods = np.round(np.clip(256./2**(self.features[:,:,self.nb_used_features-2]+1.5), 32, 255)).astype('int')
 
         self.lpc = self.features[:, :, self.nb_used_features:]
         self.features = self.features[:, :, :self.nb_used_features]
@@ -51,5 +53,9 @@
             lpc = self.lpc[index, 4:, :].copy()
         data = self.data[index, :].copy().astype(np.float32) / 2**15
         periods = self.periods[index, :].copy()
+        #lpc = lpc*(self.gamma**np.arange(1,17))
+        #lpc=lpc[None,:,:]
+        #lpc = fargan.interp_lpc(lpc, 4)
+        #lpc=lpc[0,:,:]
 
         return features, periods, data, lpc
--- a/dnn/torch/fargan/fargan.py
+++ b/dnn/torch/fargan/fargan.py
@@ -4,6 +4,8 @@
 import torch.nn.functional as F
 import filters
 from torch.nn.utils import weight_norm
+#from convert_lsp import lpc_to_lsp, lsp_to_lpc
+from rc import lpc2rc, rc2lpc
 
 Fs = 16000
 
@@ -27,6 +29,27 @@
     p = y_pred/(1e-15+torch.norm(y_pred, dim=-1, p=2, keepdim=True))
     return torch.mean(1.-torch.sum(p*t, dim=-1))
 
+def interp_lpc(lpc, factor):
+    #print(lpc.shape)
+    #f = (np.arange(factor)+.5*((factor+1)%2))/factor
+    lsp = torch.atanh(lpc2rc(lpc))
+    #print("lsp0:")
+    #print(lsp)
+    shape = lsp.shape
+    #print("shape is", shape)
+    shape = (shape[0], shape[1]*factor, shape[2])
+    interp_lsp = torch.zeros(shape, device=lpc.device)
+    for k in range(factor):
+        f = (k+.5*((factor+1)%2))/factor
+        interp = (1-f)*lsp[:,:-1,:] + f*lsp[:,1:,:]
+        interp_lsp[:,factor//2+k:-(factor//2):factor,:] = interp
+    for k in range(factor//2):
+        interp_lsp[:,k,:] = interp_lsp[:,factor//2,:]
+    for k in range((factor+1)//2):
+        interp_lsp[:,-k-1,:] = interp_lsp[:,-(factor+3)//2,:]
+    #print("lsp:")
+    #print(interp_lsp)
+    return rc2lpc(torch.tanh(interp_lsp))
 
 def analysis_filter(x, lpc, nb_subframes=4, subframe_size=40, gamma=.9):
     device = x.device
@@ -39,9 +62,9 @@
     x = torch.reshape(x, (batch_size, nb_frames*nb_subframes, subframe_size))
     out = torch.zeros((batch_size, 0), device=device)
 
-    if gamma is not None:
-        bw = gamma**(torch.arange(1, 17, device=device))
-        lpc = lpc*bw[None,None,:]
+    #if gamma is not None:
+    #    bw = gamma**(torch.arange(1, 17, device=device))
+    #    lpc = lpc*bw[None,None,:]
     ones = torch.ones((*(lpc.shape[:-1]), 1), device=device)
     zeros = torch.zeros((*(lpc.shape[:-1]), subframe_size-1), device=device)
     a = torch.cat([ones, lpc], -1)
@@ -127,23 +150,27 @@
         out = self.glu(torch.tanh(self.conv(xcat)))
         return out, xcat[:,self.in_size:]
 
+def n(x):
+    return torch.clamp(x + (1./127.)*(torch.rand_like(x)-.5), min=-1., max=1.)
+
 class FARGANCond(nn.Module):
-    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=64):
+    def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
         super(FARGANCond, self).__init__()
 
         self.feature_dim = feature_dim
         self.cond_size = cond_size
 
-        self.pembed = nn.Embedding(256, pembed_dims)
-        self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, self.cond_size, bias=False)
-        self.fconv1 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
-        self.fconv2 = nn.Conv1d(self.cond_size, self.cond_size, kernel_size=3, padding='valid', bias=False)
-        self.fdense2 = nn.Linear(self.cond_size, 80*4, bias=False)
+        self.pembed = nn.Embedding(224, pembed_dims)
+        self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
+        self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
+        self.fconv2 = nn.Conv1d(128, 80*4, kernel_size=3, padding='valid', bias=False)
 
         self.apply(init_weights)
+        nb_params = sum(p.numel() for p in self.parameters())
+        print(f"cond model: {nb_params} weights")
 
     def forward(self, features, period):
-        p = self.pembed(period)
+        p = self.pembed(period-32)
         features = torch.cat((features, p), -1)
         tmp = torch.tanh(self.fdense1(features))
         tmp = tmp.permute(0, 2, 1)
@@ -150,7 +177,7 @@
         tmp = torch.tanh(self.fconv1(tmp))
         tmp = torch.tanh(self.fconv2(tmp))
         tmp = tmp.permute(0, 2, 1)
-        tmp = torch.tanh(self.fdense2(tmp))
+        #tmp = torch.tanh(self.fdense2(tmp))
         return tmp
 
 class FARGANSub(nn.Module):
@@ -160,70 +187,87 @@
         self.subframe_size = subframe_size
         self.nb_subframes = nb_subframes
         self.cond_size = cond_size
+        self.cond_gain_dense = nn.Linear(80, 1)
 
         #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
-        self.fwc0 = FWConv(4*self.subframe_size+80, self.cond_size)
-        self.sig_dense2 = nn.Linear(self.cond_size, self.cond_size, bias=False)
-        self.gru1 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
-        self.gru2 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
-        self.gru3 = nn.GRUCell(self.cond_size, self.cond_size, bias=False)
+        self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size)
+        self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
+        self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, 128, bias=False)
+        self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)
 
         self.dense1_glu = GLU(self.cond_size)
-        self.dense2_glu = GLU(self.cond_size)
         self.gru1_glu = GLU(self.cond_size)
-        self.gru2_glu = GLU(self.cond_size)
-        self.gru3_glu = GLU(self.cond_size)
-        self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
+        self.gru2_glu = GLU(128)
+        self.gru3_glu = GLU(128)
+        self.skip_glu = GLU(self.cond_size)
+        #self.ptaps_dense = nn.Linear(4*self.cond_size, 5)
 
-        self.sig_dense_out = nn.Linear(4*self.cond_size, self.subframe_size, bias=False)
-        self.gain_dense_out = nn.Linear(4*self.cond_size, 1)
+        self.skip_dense = nn.Linear(2*128+2*self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
+        self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False)
+        self.gain_dense_out = nn.Linear(self.cond_size, 4)
 
 
         self.apply(init_weights)
+        nb_params = sum(p.numel() for p in self.parameters())
+        print(f"subframe model: {nb_params} weights")
 
-    def forward(self, cond, prev, exc_mem, phase, period, states, gain=None):
+    def forward(self, cond, prev_pred, exc_mem, period, states, gain=None):
         device = exc_mem.device
         #print(cond.shape, prev.shape)
 
-        dump_signal(prev, 'prev_in.f32')
-
-        idx = 256-torch.clamp(period[:,None], min=self.subframe_size+2, max=254)
+        cond = n(cond)
+        dump_signal(gain, 'gain0.f32')
+        gain = torch.exp(self.cond_gain_dense(cond))
+        dump_signal(gain, 'gain1.f32')
+        idx = 256-period[:,None]
         rng = torch.arange(self.subframe_size+4, device=device)
         idx = idx + rng[None,:] - 2
+        mask = idx >= 256
+        idx = idx - mask*period[:,None]
         pred = torch.gather(exc_mem, 1, idx)
-        pred = pred/(1e-5+gain)
+        pred = n(pred/(1e-5+gain))
 
-        prev = prev/(1e-5+gain)
+        prev = exc_mem[:,-self.subframe_size:]
+        dump_signal(prev, 'prev_in.f32')
+        prev = n(prev/(1e-5+gain))
         dump_signal(prev, 'pitch_exc.f32')
         dump_signal(exc_mem, 'exc_mem.f32')
 
-        tmp = torch.cat((cond, pred[:,2:-2], prev, phase), 1)
+        tmp = torch.cat((cond, pred, prev), 1)
+        #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
+        fpitch = pred[:,2:-2]
 
         #tmp = self.dense1_glu(torch.tanh(self.sig_dense1(tmp)))
         fwc0_out, fwc0_state = self.fwc0(tmp, states[3])
-        dense2_out = self.dense2_glu(torch.tanh(self.sig_dense2(fwc0_out)))
-        gru1_state = self.gru1(dense2_out, states[0])
-        gru1_out = self.gru1_glu(gru1_state)
-        gru2_state = self.gru2(gru1_out, states[1])
-        gru2_out = self.gru2_glu(gru2_state)
-        gru3_state = self.gru3(gru2_out, states[2])
-        gru3_out = self.gru3_glu(gru3_state)
-        gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, dense2_out], 1)
-        sig_out = torch.tanh(self.sig_dense_out(gru3_out))
+        fwc0_out = n(fwc0_out)
+        pitch_gain = torch.sigmoid(self.gain_dense_out(fwc0_out))
+
+        gru1_state = self.gru1(torch.cat([fwc0_out, pitch_gain[:,0:1]*fpitch, prev], 1), states[0])
+        gru1_out = self.gru1_glu(n(gru1_state))
+        gru1_out = n(gru1_out)
+        gru2_state = self.gru2(torch.cat([gru1_out, pitch_gain[:,1:2]*fpitch, prev], 1), states[1])
+        gru2_out = self.gru2_glu(n(gru2_state))
+        gru2_out = n(gru2_out)
+        gru3_state = self.gru3(torch.cat([gru2_out, pitch_gain[:,2:3]*fpitch, prev], 1), states[2])
+        gru3_out = self.gru3_glu(n(gru3_state))
+        gru3_out = n(gru3_out)
+        gru3_out = torch.cat([gru1_out, gru2_out, gru3_out, fwc0_out], 1)
+        skip_out = torch.tanh(self.skip_dense(torch.cat([gru3_out, pitch_gain[:,3:4]*fpitch, prev], 1)))
+        skip_out = self.skip_glu(n(skip_out))
+        sig_out = torch.tanh(self.sig_dense_out(skip_out))
         dump_signal(sig_out, 'exc_out.f32')
-        taps = self.ptaps_dense(gru3_out)
-        taps = .2*taps + torch.exp(taps)
-        taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
-        dump_signal(taps, 'taps.f32')
-        #fpitch = taps[:,0:1]*pred[:,:-4] + taps[:,1:2]*pred[:,1:-3] + taps[:,2:3]*pred[:,2:-2] + taps[:,3:4]*pred[:,3:-1] + taps[:,4:]*pred[:,4:]
-        fpitch = pred[:,2:-2]
+        #taps = self.ptaps_dense(gru3_out)
+        #taps = .2*taps + torch.exp(taps)
+        #taps = taps / (1e-2 + torch.sum(torch.abs(taps), dim=-1, keepdim=True))
+        #dump_signal(taps, 'taps.f32')
 
-        pitch_gain = torch.exp(self.gain_dense_out(gru3_out))
         dump_signal(pitch_gain, 'pgain.f32')
-        sig_out = (sig_out + pitch_gain*fpitch) * gain
+        #sig_out = (sig_out + pitch_gain*fpitch) * gain
+        sig_out = sig_out * gain
         exc_mem = torch.cat([exc_mem[:,self.subframe_size:], sig_out], 1)
+        prev_pred = torch.cat([prev_pred[:,self.subframe_size:], fpitch], 1)
         dump_signal(sig_out, 'sig_out.f32')
-        return sig_out, exc_mem, (gru1_state, gru2_state, gru3_state, fwc0_state)
+        return sig_out, exc_mem, prev_pred, (gru1_state, gru2_state, gru3_state, fwc0_state)
 
 class FARGAN(nn.Module):
     def __init__(self, subframe_size=40, nb_subframes=4, feature_dim=20, cond_size=256, passthrough_size=0, has_gain=False, gamma=None):
@@ -242,37 +286,30 @@
         device = features.device
         batch_size = features.size(0)
 
-        phase_real, phase_imag = gen_phase_embedding(period[:, 3:-1], self.frame_size)
-        #np.round(32000*phase.detach().numpy()).astype('int16').tofile('phase.sw')
-
-        prev = torch.zeros(batch_size, self.subframe_size, device=device)
+        prev = torch.zeros(batch_size, 256, device=device)
         exc_mem = torch.zeros(batch_size, 256, device=device)
         nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0
 
         states = (
             torch.zeros(batch_size, self.cond_size, device=device),
-            torch.zeros(batch_size, self.cond_size, device=device),
-            torch.zeros(batch_size, self.cond_size, device=device),
-            torch.zeros(batch_size, (4*self.subframe_size+80)*2, device=device)
+            torch.zeros(batch_size, 128, device=device),
+            torch.zeros(batch_size, 128, device=device),
+            torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device)
         )
 
         sig = torch.zeros((batch_size, 0), device=device)
         cond = self.cond_net(features, period)
         if pre is not None:
-            prev[:,:] = pre[:, self.frame_size-self.subframe_size : self.frame_size]
             exc_mem[:,-self.frame_size:] = pre[:, :self.frame_size]
         start = 1 if nb_pre_frames>0 else 0
         for n in range(start, nb_frames+nb_pre_frames):
             for k in range(self.nb_subframes):
                 pos = n*self.frame_size + k*self.subframe_size
-                preal = phase_real[:, pos:pos+self.subframe_size]
-                pimag = phase_imag[:, pos:pos+self.subframe_size]
-                phase = torch.cat([preal, pimag], 1)
                 #print("now: ", preal.shape, prev.shape, sig_in.shape)
                 pitch = period[:, 3+n]
                 gain = .03*10**(0.5*features[:, 3+n, 0:1]/np.sqrt(18.0))
                 #gain = gain[:,:,None]
-                out, exc_mem, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, phase, pitch, states, gain=gain)
+                out, exc_mem, prev, states = self.sig_net(cond[:, n, k*80:(k+1)*80], prev, exc_mem, pitch, states, gain=gain)
 
                 if n < nb_pre_frames:
                     out = pre[:, pos:pos+self.subframe_size]
@@ -280,6 +317,5 @@
                 else:
                     sig = torch.cat([sig, out], 1)
 
-                prev = out
         states = [s.detach() for s in states]
         return sig, states
--- /dev/null
+++ b/dnn/torch/fargan/rc.py
@@ -1,0 +1,29 @@
+import torch
+
+
+
+def rc2lpc(rc):
+    order = rc.shape[-1]
+    lpc=rc[...,0:1]
+    for i in range(1, order):
+        lpc = torch.cat([lpc + rc[...,i:i+1]*torch.flip(lpc,dims=(-1,)), rc[...,i:i+1]], -1)
+        #print("to:", lpc)
+    return lpc
+
+def lpc2rc(lpc):
+    order = lpc.shape[-1]
+    rc = lpc[...,-1:]
+    for i in range(order-1, 0, -1):
+        ki = lpc[...,-1:]
+        lpc = lpc[...,:-1]
+        lpc = (lpc - ki*torch.flip(lpc,dims=(-1,)))/(1 - ki*ki)
+        rc = torch.cat([lpc[...,-1:] , rc], -1)
+    return rc
+
+if __name__ == "__main__":
+    rc = torch.tensor([[.5, -.5, .6, -.6]])
+    print(rc)
+    lpc = rc2lpc(rc)
+    print(lpc)
+    rc2 = lpc2rc(lpc)
+    print(rc2)
--- a/dnn/torch/fargan/stft_loss.py
+++ b/dnn/torch/fargan/stft_loss.py
@@ -44,7 +44,9 @@
         Returns:
             Tensor: Spectral convergence loss value.
         """
-        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+        x_mag = torch.sqrt(x_mag)
+        y_mag = torch.sqrt(y_mag)
+        return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
 
 class LogSTFTMagnitudeLoss(torch.nn.Module):
     """Log STFT magnitude loss module."""
@@ -136,14 +138,14 @@
 
 class MultiResolutionSTFTLoss(torch.nn.Module):
 
-    def __init__(self,
+    '''def __init__(self,
                  device,
                  fft_sizes=[2048, 1024, 512, 256, 128, 64],
                  hop_sizes=[512, 256, 128, 64, 32, 16],
                  win_lengths=[2048, 1024, 512, 256, 128, 64],
-                 window="hann_window"):
+                 window="hann_window"):'''
 
-        '''def __init__(self,
+    '''def __init__(self,
                  device,
                  fft_sizes=[2048, 1024, 512, 256, 128, 64],
                  hop_sizes=[256, 128, 64, 32, 16, 8],
@@ -150,12 +152,12 @@
                  win_lengths=[1024, 512, 256, 128, 64, 32],
                  window="hann_window"):'''
 
-        '''def __init__(self,
+    def __init__(self,
                  device,
                  fft_sizes=[2560, 1280, 640, 320, 160, 80],
                  hop_sizes=[640, 320, 160, 80, 40, 20],
                  win_lengths=[2560, 1280, 640, 320, 160, 80],
-                 window="hann_window"):'''
+                 window="hann_window"):
 
         super(MultiResolutionSTFTLoss, self).__init__()
         assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
--- a/dnn/torch/fargan/test_fargan.py
+++ b/dnn/torch/fargan/test_fargan.py
@@ -48,8 +48,10 @@
 features = np.reshape(np.memmap(features_file, dtype='float32', mode='r'), (1, -1, nb_features))
 lpc = features[:,4-1:-1,nb_used_features:]
 features = features[:, :, :nb_used_features]
-periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')
+#periods = np.round(50*features[:,:,nb_used_features-2]+100).astype('int')
+periods = np.round(np.clip(256./2**(features[:,:,nb_used_features-2]+1.5), 32, 255)).astype('int')
 
+
 nb_frames = features.shape[1]
 #nb_frames = 1000
 gamma = checkpoint['model_kwargs']['gamma']
@@ -90,18 +92,37 @@
         buffer[:] = out_sig_frame[-16:]
     return signal
 
+def inverse_perceptual_weighting40 (pw_signal, filters):
 
+    #inverse perceptual weighting= H_preemph / W(z/gamma)
 
+    signal = np.zeros_like(pw_signal)
+    buffer = np.zeros(16)
+    num_frames = pw_signal.shape[0] //40
+    assert num_frames == filters.shape[0]
+    for frame_idx in range(0, num_frames):
+        in_frame = pw_signal[frame_idx*40: (frame_idx+1)*40][:]
+        out_sig_frame = lpc_synthesis_one_frame(in_frame, filters[frame_idx, :], buffer)
+        signal[frame_idx*40: (frame_idx+1)*40] = out_sig_frame[:]
+        buffer[:] = out_sig_frame[-16:]
+    return signal
+
+from scipy.signal import lfilter
+
 if __name__ == '__main__':
     model.to(device)
     features = torch.tensor(features).to(device)
     #lpc = torch.tensor(lpc).to(device)
     periods = torch.tensor(periods).to(device)
+    weighting = gamma**np.arange(1, 17)
+    lpc = lpc*weighting
+    lpc = fargan.interp_lpc(torch.tensor(lpc), 4).numpy()
 
     sig, _ = model(features, periods, nb_frames - 4)
-    weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
+    #weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
     sig = sig.detach().numpy().flatten()
-    sig = inverse_perceptual_weighting(sig, lpc[0,:,:], weighting_vector)
+    sig = lfilter(np.array([1.]), np.array([1., -.85]), sig)
+    #sig = inverse_perceptual_weighting40(sig, lpc[0,:,:])
 
     pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
     pcm.tofile(signal_file)
--- a/dnn/torch/fargan/train_fargan.py
+++ b/dnn/torch/fargan/train_fargan.py
@@ -114,20 +114,25 @@
             for i, (features, periods, target, lpc) in enumerate(tepoch):
                 optimizer.zero_grad()
                 features = features.to(device)
-                lpc = lpc.to(device)
+                #lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4))
+                #print("interp size", lpc.shape)
+                #lpc = lpc.to(device)
+                #lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
+                #lpc = fargan.interp_lpc(lpc, 4)
                 periods = periods.to(device)
                 if (np.random.rand() > 0.1):
                     target = target[:, :sequence_length*160]
-                    lpc = lpc[:,:sequence_length,:]
+                    #lpc = lpc[:,:sequence_length*4,:]
                     features = features[:,:sequence_length+4,:]
                     periods = periods[:,:sequence_length+4]
                 else:
                     target=target[::2, :]
-                    lpc=lpc[::2,:]
+                    #lpc=lpc[::2,:]
                     features=features[::2,:]
                     periods=periods[::2,:]
                 target = target.to(device)
-                target = fargan.analysis_filter(target, lpc[:,:,:], gamma=args.gamma)
+                #print(target.shape, lpc.shape)
+                #target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
 
                 #nb_pre = random.randrange(1, 6)
                 nb_pre = 2
@@ -135,9 +140,9 @@
                 sig, states = model(features, periods, target.size(1)//160 - nb_pre, pre=pre, states=None)
                 sig = torch.cat([pre, sig], -1)
 
-                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], sig[:, nb_pre*160:nb_pre*160+80])
+                cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+160], sig[:, nb_pre*160:nb_pre*160+160])
                 specc_loss = spect_loss(sig, target.detach())
-                loss = .00*cont_loss + specc_loss
+                loss = .03*cont_loss + specc_loss
 
                 loss.backward()
                 optimizer.step()
--