ref: 0dc559f060db0d62d95f424e3fd26a5f673b2f6b
parent: 5667867fa293dbbc632a7c78308f4ad2db9c52df
author: Jan Buethe <jbuethe@amazon.de>
date: Wed Apr 24 08:17:51 EDT 2024
added some bwe-related stuff
--- /dev/null
+++ b/dnn/torch/osce/losses/td_lowpass.py
@@ -1,0 +1,79 @@
+import torch
+import scipy.signal
+
+
+from utils.layers.fir import FIR
+
+
+class TDLowpass(torch.nn.Module):
+    """Time-domain loss on the low-pass filtered difference of two signals.
+
+    The difference ``y_true - y_pred`` is filtered with an FIR low-pass designed
+    by ``scipy.signal.firwin`` and the mean of ``|filtered| ** power`` is
+    returned as the loss.
+    """
+
+    def __init__(self, numtaps, cutoff, power=2):
+        """Design the low-pass filter.
+
+        Args:
+            numtaps: number of filter taps, forwarded to ``scipy.signal.firwin``.
+            cutoff: cutoff frequency, forwarded to ``scipy.signal.firwin``; no
+                ``fs`` is passed, so scipy's default frequency normalisation
+                applies.
+            power: exponent applied to the magnitude of the filtered
+                difference before averaging.
+        """
+        super().__init__()
+
+        self.b = scipy.signal.firwin(numtaps, cutoff)
+        self.power = power
+
+        # A non-persistent buffer follows the module through .to()/.cuda()
+        # and DataParallel replication without adding a key to state_dict().
+        self.register_buffer(
+            "weight",
+            torch.from_numpy(self.b).float().view(1, 1, -1),
+            persistent=False,
+        )
+
+    def forward(self, y_true, y_pred):
+        """Compute the loss between a reference and a predicted signal.
+
+        Args:
+            y_true: reference signal, a 3-dimensional tensor. The filter has a
+                single input channel, so dimension 1 must have size 1.
+            y_pred: predicted signal, same layout as ``y_true``.
+
+        Returns:
+            Tuple ``(loss, diff_lp)``: the scalar loss and the filtered
+            difference. No padding is applied, so the last dimension of
+            ``diff_lp`` is ``numtaps - 1`` samples shorter than the input.
+
+        Raises:
+            ValueError: if either input is not 3-dimensional.
+        """
+        if y_true.dim() != 3 or y_pred.dim() != 3:
+            raise ValueError(
+                "TDLowpass expects 3-dimensional inputs, got shapes "
+                f"{tuple(y_true.shape)} and {tuple(y_pred.shape)}"
+            )
+
+        diff = y_true - y_pred
+
+        # Also covers a loss object that was never moved to the inputs' device;
+        # this is a no-op when the devices already match.
+        weight = self.weight.to(diff.device)
+        diff_lp = torch.nn.functional.conv1d(diff, weight)
+
+        # |x| ** p equals |x ** p| for integer p but stays finite for
+        # fractional p, where x ** p is NaN for negative x.
+        loss = torch.mean(torch.abs(diff_lp) ** self.power)
+
+        return loss, diff_lp
+
+    def get_freqz(self):
+        """Return ``(freq, response)`` from ``scipy.signal.freqz`` for the taps."""
+        freq, response = scipy.signal.freqz(self.b)
+
+        return freq, response
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/silk_16_to_48.py
@@ -1,0 +1,52 @@
+"""Upsample a 16 kHz, 16-bit mono wave file to 48 kHz with SilkUpsampler."""
+
+import argparse
+
+from scipy.io import wavfile
+import torch
+import numpy as np
+
+from utils.layers.silk_upsampler import SilkUpsampler
+
+# NOTE(review): number of leading output samples dropped before writing;
+# presumably compensates the delay of SilkUpsampler -- confirm.
+OUTPUT_SKIP = 13
+
+parser = argparse.ArgumentParser()
+parser.add_argument("input", type=str, help="input wave file")
+parser.add_argument("output", type=str, help="output wave file")
+
+
+def main():
+    """Read the input file, upsample it and write the result at 48 kHz."""
+    args = parser.parse_args()
+
+    fs, x = wavfile.read(args.input)
+
+    # Only 16 kHz, int16, single-channel input is handled. Raising instead of
+    # asserting keeps the check alive under ``python -O``; a 2-D (multi-channel)
+    # array would otherwise be flattened into one signal by view() below.
+    if fs != 16000 or x.dtype != np.int16 or x.ndim != 1:
+        raise ValueError(
+            "expected a 16 kHz, 16-bit, mono wave file, got "
+            f"fs={fs}, dtype={x.dtype}, shape={x.shape}"
+        )
+
+    x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)
+
+    upsampler = SilkUpsampler()
+    # Inference only: skip autograd so that numpy() below cannot fail on a
+    # tensor that requires grad.
+    with torch.no_grad():
+        y = upsampler(x)
+
+    # Round and saturate before the cast: a bare astype(np.int16) truncates
+    # towards zero and does not saturate values outside the int16 range.
+    y = y.squeeze().numpy()
+    y = np.clip(np.rint(y), -32768, 32767).astype(np.int16)
+
+    wavfile.write(args.output, 48000, y[OUTPUT_SKIP:])
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/fir.py
@@ -1,0 +1,64 @@
+import numpy as np
+import scipy.signal
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class FIR(nn.Module):
+    """Fixed FIR filter, designed by least squares, applied to every channel.
+
+    The taps come from ``scipy.signal.firls`` and are not trainable: they are
+    stored as a buffer, not as a parameter.
+    """
+
+    def __init__(self, numtaps, bands, desired, fs=2):
+        """Design the filter.
+
+        Args:
+            numtaps: number of taps; an even value is increased by one (with a
+                printed warning) before the design call.
+            bands: band edges, forwarded to ``scipy.signal.firls``.
+            desired: desired gains at the band edges, forwarded to
+                ``scipy.signal.firls``.
+            fs: sampling frequency in which ``bands`` is expressed, forwarded
+                to ``scipy.signal.firls``.
+        """
+        super().__init__()
+
+        if numtaps % 2 == 0:
+            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
+            numtaps += 1
+
+        a = scipy.signal.firls(numtaps, bands, desired, fs=fs)
+
+        # A non-persistent buffer follows the module through .to()/.cuda()
+        # and DataParallel replication without adding a key to state_dict().
+        self.register_buffer(
+            "weight",
+            torch.from_numpy(a.astype(np.float32)),
+            persistent=False,
+        )
+
+    def forward(self, x):
+        """Filter each channel of ``x`` with the same taps.
+
+        Args:
+            x: input tensor whose dimension 1 is the channel dimension and
+                whose last dimension is filtered.
+
+        Returns:
+            The filtered tensor with the same number of channels. No padding
+            is applied, so the last dimension is shorter than the input by
+            the number of taps minus one.
+        """
+        num_channels = x.size(1)
+
+        # One copy of the taps per channel; groups=num_channels makes conv1d
+        # filter every channel independently.
+        taps = self.weight.to(x.device).view(1, 1, -1)
+        weight = torch.repeat_interleave(taps, num_channels, 0)
+
+        y = F.conv1d(x, weight, groups=num_channels)
+
+        return y
\ No newline at end of file
--
⑨