import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

from utils.layers.silk_upsampler import SilkUpsampler
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.layers.deemph import Deemph
from utils.misc import freeze_model

from models.nns_base import NNSBase
from models.silk_feature_net_pl import SilkFeatureNetPL
from models.silk_feature_net import SilkFeatureNet
from .scale_embedding import ScaleEmbedding



class ShapeUp48(NNSBase):
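    """Upsamples a 16 kHz signal to 48 kHz with feature-conditioned shaping.

    The signal path is 16 kHz -> (2x) 32 kHz -> (3/2) 48 kHz. At 32 kHz the
    signal is split into a linearly filtered branch, a temporally shaped branch
    and a spectrally shaped noise branch, which are remixed by a limited
    adaptive convolution; the split/shape/remix pattern is repeated (without
    the noise branch) at 48 kHz. An optional frozen prenet can preprocess the
    16 kHz signal before upsampling.
    """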
    FRAME_SIZE16k = 80  # samples per 5 ms frame at 16 kHz

    def __init__(self,
                 num_features=47,
                 pitch_embedding_dim=64,
                 cond_dim=256,
                 pitch_max=257,
                 kernel_size=15,
                 preemph=0.85,
                 skip=288,
                 conv_gain_limits_db=[-6, 6],
                 numbits_range=[50, 650],
                 numbits_embedding_dim=8,
                 hidden_feature_dim=64,
                 partial_lookahead=True,
                 norm_p=2,
                 target_fs=48000,
                 noise_amplitude=0,
                 prenet=None,
                 avg_pool_k=4):

        super().__init__(skip=skip, preemph=preemph)


        self.num_features           = num_features
        self.cond_dim               = cond_dim
        self.pitch_max              = pitch_max
        self.pitch_embedding_dim    = pitch_embedding_dim
        self.kernel_size            = kernel_size
        self.preemph                = preemph
        self.skip                   = skip
        self.numbits_range          = numbits_range
        self.numbits_embedding_dim  = numbits_embedding_dim
        self.hidden_feature_dim     = hidden_feature_dim
        self.partial_lookahead      = partial_lookahead
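        # frame sizes at the target (48 kHz) and intermediate (32 kHz) rates;
        # the + .1 guards the int() truncation against floating-point round-off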
        self.frame_size48           = int(self.FRAME_SIZE16k * target_fs / 16000 + .1)
        self.frame_size32           = self.FRAME_SIZE16k * 2
        self.noise_amplitude        = noise_amplitude
        self.prenet                 = prenet

        # freeze prenet if given
        if prenet is not None:
            freeze_model(self.prenet)
            try:
                self.deemph = Deemph(prenet.preemph)
            except AttributeError:
                print("[warning] prenet model is expected to have a preemph attribute")
                self.deemph = Deemph(0)

        # upsampler
        self.upsampler = SilkUpsampler()

        # pitch embedding
        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)

        # numbits embedding
        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)

        # feature net
        if partial_lookahead:
            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
        else:
            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)

        # non-linear transforms
        self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k)
        self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k)

        # spectral shaping
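        # af1 splits the 32 kHz signal into two branches, af2 remixes the shaped
        # branches together with the filtered noise down to two channels, and
        # af3 mixes the two 48 kHz branches down to the single output channel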
        self.af_noise = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=[-30, 0], norm_p=norm_p)
        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af2 = LimitedAdaptiveConv1d(3, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)

    def flop_count(self, rate=16000, verbose=False):

        frame_rate = rate / self.FRAME_SIZE16k

        # feature net runs at the 16 kHz frame rate; af1 and af2 process the
        # 2x-upsampled (32 kHz) signal and af3 the 3/2-interpolated (48 kHz)
        # signal. The TDShaper stages and the noise filter are not counted here.
        feature_net_flops = self.feature_net.flop_count(frame_rate)
        af_flops = self.af1.flop_count(2 * rate) + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate)

        if verbose:
            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")

        return feature_net_flops + af_flops

    def forward(self, x, features, periods, numbits, debug=False):
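        """Enhance a 16 kHz signal and upsample it to 48 kHz.

        Shapes below are illustrative assumptions read off the layer geometry
        (batch B, one feature frame per FRAME_SIZE16k samples, N frames total):
        x is (B, 1, N * 80), features is (B, N, num_features), periods is an
        integer tensor (B, N, 1) of pitch lags in [0, pitch_max], and numbits
        is (B, N, 2). The returned 48 kHz signal has shape (B, 1, 3 * N * 80).
        """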

        if self.prenet is not None:
            with torch.no_grad():
                x = self.prenet(x, features, periods, numbits)
                x = self.deemph(x)

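        # embed pitch and rate side information and derive the conditioning features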
        periods         = periods.squeeze(-1)
        pitch_embedding = self.pitch_embedding(periods)
        numbits_embedding = self.numbits_embedding(numbits).flatten(2)

        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
        cf = self.feature_net(full_features)

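        # 2x upsampling to 32 kHz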
        y32 = self.upsampler.hq_2x_up(x)

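        # spectrally shaped noise branch (silent with the default noise_amplitude=0)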
        noise = self.noise_amplitude * torch.randn_like(y32)
        noise = self.af_noise(noise, cf)

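        # split into a pass-through branch and a temporally shaped branch, then remix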
        y32 = self.af1(y32, cf, debug=debug)

        y32_1 = y32[:, 0:1, :]
        y32_2 = self.tdshape1(y32[:, 1:2, :], cf)
        y32 = torch.cat((y32_1, y32_2, noise), dim=1)

        y32 = self.af2(y32, cf, debug=debug)

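        # 3/2 interpolation to 48 kHz, followed by the same shape-and-remix pattern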
        y48 = self.upsampler.interpolate_3_2(y32)

        y48_1 = y48[:, 0:1, :]
        y48_2 = self.tdshape2(y48[:, 1:2, :], cf)
        y48 = torch.cat((y48_1, y48_2), dim=1)

        y48 = self.af3(y48, cf, debug=debug)

        return y48
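
# A minimal smoke sketch of the shape contract described in forward(); the
# feature, period and numbits layouts are assumptions, not the actual training
# pipeline:
#
#   model = ShapeUp48()
#   N = 100                                      # feature frames
#   x        = torch.zeros((1, 1, N * ShapeUp48.FRAME_SIZE16k))
#   features = torch.zeros((1, N, 47))
#   periods  = torch.randint(32, 257, (1, N, 1))
#   numbits  = 400 * torch.ones((1, N, 2))
#   y48 = model(x, features, periods, numbits)   # -> (1, 1, 3 * N * 80)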