shithub: opus

ref: 35ee397e060283d30c098ae5e17836316bbec08b
dir: /dnn/torch/lpcnet/models/multi_rate_lpcnet.py/

View raw version
import torch
from torch import nn
from utils.layers.subconditioner import get_subconditioner
from utils.layers import DualFC

from utils.ulaw import lin2ulawq, ulaw2lin
from utils.sample import sample_excitation
from utils.pcm import clip_to_int16
from utils.sparsification import GRUSparsifier, calculate_gru_flops_per_step

from utils.misc import interleave_tensors




# MultiRateLPCNet
class MultiRateLPCNet(nn.Module):
    def __init__(self, config):
        super(MultiRateLPCNet, self).__init__()

        # general parameters
        self.input_layout       = config['input_layout']
        self.feature_history    = config['feature_history']
        self.feature_lookahead  = config['feature_lookahead']
        self.signals            = config['signals']

        # frame rate network parameters
        self.feature_dimension          = config['feature_dimension']
        self.period_embedding_dim       = config['period_embedding_dim']
        self.period_levels              = config['period_levels']
        self.feature_channels           = self.feature_dimension + self.period_embedding_dim
        self.feature_conditioning_dim   = config['feature_conditioning_dim']
        self.feature_conv_kernel_size   = config['feature_conv_kernel_size']

        # frame rate network layers
        self.period_embedding   = nn.Embedding(self.period_levels, self.period_embedding_dim)
        self.feature_conv1      = nn.Conv1d(self.feature_channels, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_conv2      = nn.Conv1d(self.feature_conditioning_dim, self.feature_conditioning_dim, self.feature_conv_kernel_size, padding='valid')
        self.feature_dense1     = nn.Linear(self.feature_conditioning_dim, self.feature_conditioning_dim)
        self.feature_dense2     = nn.Linear(*(2*[self.feature_conditioning_dim]))

        # sample rate network parameters
        self.frame_size             = config['frame_size']
        self.signal_levels          = config['signal_levels']
        self.signal_embedding_dim   = config['signal_embedding_dim']
        self.gru_a_units            = config['gru_a_units']
        self.gru_b_units            = config['gru_b_units']
        self.output_levels          = config['output_levels']

        # subconditioning B
        sub_config = config['subconditioning']['subconditioning_b']
        self.substeps_b = sub_config['number_of_subsamples']
        self.subcondition_signals_b = sub_config['signals']
        self.signals_idx_b = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        if type(kwargs) == type(None):
            kwargs = dict()

        state_size = self.gru_b_units
        self.subconditioner_b = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, len(sub_config['signals']),
            **sub_config['kwargs'])

         # subconditioning A
        sub_config = config['subconditioning']['subconditioning_a']
        self.substeps_a = sub_config['number_of_subsamples']
        self.subcondition_signals_a = sub_config['signals']
        self.signals_idx_a = [self.input_layout['signals'][key] for key in sub_config['signals']]
        method = sub_config['method']
        kwargs = sub_config['kwargs']
        if type(kwargs) == type(None):
            kwargs = dict()

        state_size = self.gru_a_units
        self.subconditioner_a = get_subconditioner(method,
            sub_config['number_of_subsamples'], sub_config['pcm_embedding_size'],
            state_size, self.signal_levels, self.substeps_b * len(sub_config['signals']),
            **sub_config['kwargs'])


        # wrap up subconditioning, group_size_gru_a holds the number
        # of timesteps that are grouped as sample input for GRU A
        # input and group_size_subcondition_a holds the number of samples that are
        # grouped as input to pre-GRU B subconditioning
        self.group_size_gru_a = self.substeps_a * self.substeps_b
        self.group_size_subcondition_a = self.substeps_b
        self.gru_a_rate_divider = self.group_size_gru_a
        self.gru_b_rate_divider = self.substeps_b

        # gru sizes
        self.gru_a_input_dim        = self.group_size_gru_a * len(self.signals) * self.signal_embedding_dim + self.feature_conditioning_dim
        self.gru_b_input_dim        = self.subconditioner_a.get_output_dim(0) + self.feature_conditioning_dim
        self.signals_idx            = [self.input_layout['signals'][key] for key in self.signals]

        # sample rate network layers
        self.signal_embedding   = nn.Embedding(self.signal_levels, self.signal_embedding_dim)
        self.gru_a              = nn.GRU(self.gru_a_input_dim, self.gru_a_units, batch_first=True)
        self.gru_b              = nn.GRU(self.gru_b_input_dim, self.gru_b_units, batch_first=True)

        # sparsification
        self.sparsifier = []

        # GRU A
        if 'gru_a' in config['sparsification']:
            gru_config  = config['sparsification']['gru_a']
            task_list = [(self.gru_a, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a,
                                                                      gru_config['params'], drop_input=True)
        else:
            self.gru_a_flops_per_step = calculate_gru_flops_per_step(self.gru_a, drop_input=True)

        # GRU B
        if 'gru_b' in config['sparsification']:
            gru_config  = config['sparsification']['gru_b']
            task_list = [(self.gru_b, gru_config['params'])]
            self.sparsifier.append(GRUSparsifier(task_list,
                                                 gru_config['start'],
                                                 gru_config['stop'],
                                                 gru_config['interval'],
                                                 gru_config['exponent'])
            )
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b,
                                                                      gru_config['params'])
        else:
            self.gru_b_flops_per_step = calculate_gru_flops_per_step(self.gru_b)



        # dual FCs
        self.dual_fc = []
        for i in range(self.substeps_b):
            dim = self.subconditioner_b.get_output_dim(i)
            self.dual_fc.append(DualFC(dim, self.output_levels))
            self.add_module(f"dual_fc_{i}", self.dual_fc[-1])

    def get_gflops(self, fs, verbose=False, hierarchical_sampling=False):
        gflops = 0

        # frame rate network
        conditioning_dim = self.feature_conditioning_dim
        feature_channels = self.feature_channels
        frame_rate = fs / self.frame_size
        frame_rate_network_complexity = 1e-9 * 2 * (5 * conditioning_dim + 3 * feature_channels) * conditioning_dim * frame_rate
        if verbose:
            print(f"frame rate network: {frame_rate_network_complexity} GFLOPS")
        gflops += frame_rate_network_complexity

        # gru a
        gru_a_rate = fs / self.group_size_gru_a
        gru_a_complexity = 1e-9 * gru_a_rate * self.gru_a_flops_per_step
        if verbose:
            print(f"gru A: {gru_a_complexity} GFLOPS")
        gflops += gru_a_complexity

        # subconditioning a
        subcond_a_rate = fs / self.substeps_b
        subconditioning_a_complexity = 1e-9 * self.subconditioner_a.get_average_flops_per_step() * subcond_a_rate
        if verbose:
            print(f"subconditioning A: {subconditioning_a_complexity} GFLOPS")
        gflops += subconditioning_a_complexity

        # gru b
        gru_b_rate = fs / self.substeps_b
        gru_b_complexity = 1e-9 * gru_b_rate * self.gru_b_flops_per_step
        if verbose:
            print(f"gru B: {gru_b_complexity} GFLOPS")
        gflops += gru_b_complexity

        # subconditioning b
        subcond_b_rate = fs
        subconditioning_b_complexity = 1e-9 * self.subconditioner_b.get_average_flops_per_step() * subcond_b_rate
        if verbose:
            print(f"subconditioning B: {subconditioning_b_complexity} GFLOPS")
        gflops += subconditioning_b_complexity

        # dual fcs
        for i, fc in enumerate(self.dual_fc):
            rate = fs / len(self.dual_fc)
            input_size = fc.dense1.in_features
            output_size = fc.dense1.out_features
            dual_fc_complexity = 1e-9 *  (4 * input_size * output_size + 22 * output_size) * rate
            if hierarchical_sampling:
                dual_fc_complexity /= 8
            if verbose:
                print(f"dual_fc_{i}: {dual_fc_complexity} GFLOPS")
            gflops += dual_fc_complexity

        if verbose:
            print(f'total: {gflops} GFLOPS')

        return gflops



    def sparsify(self):
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def frame_rate_network(self, features, periods):

        embedded_periods = torch.flatten(self.period_embedding(periods), 2, 3)
        features = torch.concat((features, embedded_periods), dim=-1)

        # convert to channels first and calculate conditioning vector
        c = torch.permute(features, [0, 2, 1])

        c = torch.tanh(self.feature_conv1(c))
        c = torch.tanh(self.feature_conv2(c))
        # back to channels last
        c = torch.permute(c, [0, 2, 1])
        c = torch.tanh(self.feature_dense1(c))
        c = torch.tanh(self.feature_dense2(c))

        return c

    def prepare_signals(self, signals, group_size, signal_idx):
        """ extracts, delays and groups signals """

        batch_size, sequence_length, num_signals = signals.shape

        # extract signals according to position
        signals = torch.cat([signals[:, :, i : i + 1] for i in signal_idx],
                            dim=-1)

        # roll back pcm to account for grouping
        signals  = torch.roll(signals, group_size - 1, -2)

        # reshape
        signals = torch.reshape(signals,
            (batch_size, sequence_length // group_size, group_size * len(signal_idx)))

        return signals


    def sample_rate_network(self, signals, c, gru_states):

        signals_a        = self.prepare_signals(signals, self.group_size_gru_a, self.signals_idx)
        embedded_signals = torch.flatten(self.signal_embedding(signals_a), 2, 3)
        # features at GRU A rate
        c_upsampled_a    = torch.repeat_interleave(c, self.frame_size // self.gru_a_rate_divider, dim=1)
        # features at GRU B rate
        c_upsampled_b    = torch.repeat_interleave(c, self.frame_size // self.gru_b_rate_divider, dim=1)

        y = torch.concat((embedded_signals, c_upsampled_a), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        # first round of upsampling and subconditioning
        c_signals_a = self.prepare_signals(signals, self.group_size_subcondition_a, self.signals_idx_a)
        y = self.subconditioner_a(y, c_signals_a)
        y = interleave_tensors(y)

        y = torch.concat((y, c_upsampled_b), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])
        c_signals_b = self.prepare_signals(signals, 1, self.signals_idx_b)
        y = self.subconditioner_b(y, c_signals_b)

        y = [self.dual_fc[i](y[i]) for i in range(self.substeps_b)]
        y = interleave_tensors(y)

        return y, (gru_a_state, gru_b_state)

    def decoder(self, signals, c, gru_states):
        embedded_signals = torch.flatten(self.signal_embedding(signals), 2, 3)

        y = torch.concat((embedded_signals, c), dim=-1)
        y, gru_a_state = self.gru_a(y, gru_states[0])
        y = torch.concat((y, c), dim=-1)
        y, gru_b_state = self.gru_b(y, gru_states[1])

        y = self.dual_fc(y)

        return torch.softmax(y, dim=-1), (gru_a_state, gru_b_state)

    def forward(self, features, periods, signals, gru_states):

        c           = self.frame_rate_network(features, periods)
        y, _        = self.sample_rate_network(signals, c, gru_states)
        log_probs   = torch.log_softmax(y, dim=-1)

        return log_probs

    def generate(self, features, periods, lpcs):

        with torch.no_grad():
            device = self.parameters().__next__().device

            num_frames          = features.shape[0] - self.feature_history - self.feature_lookahead
            lpc_order           = lpcs.shape[-1]
            num_input_signals   = len(self.signals)
            pitch_corr_position = self.input_layout['features']['pitch_corr'][0]

            # signal buffers
            last_signal       = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            prediction        = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            last_error        = torch.zeros((num_frames * self.frame_size + lpc_order + 1))
            output            = torch.zeros((num_frames * self.frame_size), dtype=torch.int16)
            mem = 0

            # state buffers
            gru_a_state = torch.zeros((1, 1, self.gru_a_units))
            gru_b_state = torch.zeros((1, 1, self.gru_b_units))

            input_signals = 128 + torch.zeros(self.group_size_gru_a * num_input_signals, dtype=torch.long)
            # conditioning signals for subconditioner a
            c_signals_a   = 128 + torch.zeros(self.group_size_subcondition_a * len(self.signals_idx_a), dtype=torch.long)
            # conditioning signals for subconditioner b
            c_signals_b   = 128 + torch.zeros(len(self.signals_idx_b), dtype=torch.long)

            # signal dict
            signal_dict = {
                'prediction'    : prediction,
                'last_error'    : last_error,
                'last_signal'   : last_signal
            }

            # push data to device
            features = features.to(device)
            periods  = periods.to(device)
            lpcs     = lpcs.to(device)

            # run feature encoding
            c = self.frame_rate_network(features.unsqueeze(0), periods.unsqueeze(0))

            for frame_index in range(num_frames):
                frame_start = frame_index * self.frame_size
                pitch_corr  = features[frame_index + self.feature_history, pitch_corr_position]
                a           = - torch.flip(lpcs[frame_index + self.feature_history], [0])
                current_c   = c[:, frame_index : frame_index + 1, :]

                for i in range(0, self.frame_size, self.group_size_gru_a):
                    pcm_position    = frame_start + i + lpc_order
                    output_position = frame_start + i

                    # calculate newest prediction
                    prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)

                    # prepare input
                    for slot in range(self.group_size_gru_a):
                        k = slot - self.group_size_gru_a + 1
                        for idx, name in enumerate(self.signals):
                            input_signals[idx + slot * num_input_signals] = lin2ulawq(
                                signal_dict[name][pcm_position + k]
                            )


                    # run GRU A
                    embed_signals   = self.signal_embedding(input_signals.reshape((1, 1, -1)))
                    embed_signals   = torch.flatten(embed_signals, 2)
                    y               = torch.cat((embed_signals, current_c), dim=-1)
                    h_a, gru_a_state  = self.gru_a(y, gru_a_state)

                    # loop over substeps_a
                    for step_a in range(self.substeps_a):
                        # prepare conditioning input
                        for slot in range(self.group_size_subcondition_a):
                            k = slot - self.group_size_subcondition_a + 1
                            for idx, name in enumerate(self.subcondition_signals_a):
                                c_signals_a[idx + slot * num_input_signals] = lin2ulawq(
                                    signal_dict[name][pcm_position + k]
                                )

                        # subconditioning
                        h_a = self.subconditioner_a.single_step(step_a, h_a, c_signals_a.reshape((1, 1, -1)))

                        # run GRU B
                        y = torch.cat((h_a, current_c), dim=-1)
                        h_b, gru_b_state = self.gru_b(y, gru_b_state)

                        # loop over substeps b
                        for step_b in range(self.substeps_b):
                            # prepare subconditioning input
                            for idx, name in enumerate(self.subcondition_signals_b):
                                c_signals_b[idx] = lin2ulawq(
                                    signal_dict[name][pcm_position]
                                )

                            # subcondition
                            h_b = self.subconditioner_b.single_step(step_b, h_b, c_signals_b.reshape((1, 1, -1)))

                            # run dual FC
                            probs = torch.softmax(self.dual_fc[step_b](h_b), dim=-1)

                            # sample
                            new_exc = ulaw2lin(sample_excitation(probs, pitch_corr))

                            # update signals
                            sig = new_exc + prediction[pcm_position]
                            last_error[pcm_position + 1] = new_exc
                            last_signal[pcm_position + 1] = sig

                            mem = 0.85 * mem + float(sig)
                            output[output_position] = clip_to_int16(round(mem))

                            # increase positions
                            pcm_position += 1
                            output_position += 1

                            # calculate next prediction
                            prediction[pcm_position] = torch.sum(last_signal[pcm_position - lpc_order + 1: pcm_position + 1] * a)

        return output