shithub: opus

ref: 35ee397e060283d30c098ae5e17836316bbec08b
dir: /dnn/torch/lpcnet/data/lpcnet_dataset.py/

View raw version
""" Dataset for LPCNet training """
import os

import yaml
import torch
import numpy as np
from torch.utils.data import Dataset


# u-law companding constants (mu = 255, 16-bit full scale of 32768)
scale = 255.0/32768.0    # linear amplitude -> mu-scaled amplitude
scale_1 = 32768.0/255.0  # inverse of `scale`


def ulaw2lin(u):
    """Expand u-law codes to linear amplitudes.

    Args:
        u: scalar or array-like of u-law codes, where 128 encodes zero,
            values below 128 encode negative and values above 128 encode
            positive amplitudes.

    Returns:
        Floating point linear amplitudes with the same shape as ``u``.
    """
    u = np.asarray(u)
    if u.dtype.kind in 'iu':
        # Integer codes are widened first: with e.g. uint8 input the
        # subtraction below would wrap around and every code below 128
        # would decode to a positive amplitude.
        u = u.astype(np.int64)

    u = u - 128
    s = np.sign(u)
    u = np.abs(u)
    return s*scale_1*(np.exp(u/128.*np.log(256))-1)


def lin2ulaw(x):
    """Compress linear amplitudes to u-law codes.

    Args:
        x: scalar or array-like of linear amplitudes (full scale 32768).

    Returns:
        Floating point array of rounded codes clipped to [0, 255], where 128
        encodes zero.
    """
    x = np.asarray(x)
    if x.dtype.kind in 'iu':
        # Integer samples are converted to float first: np.abs on the most
        # negative value of a signed type (e.g. int16 -32768) overflows and
        # would make the logarithm below return NaN.
        x = x.astype(np.float64)

    s = np.sign(x)
    x = np.abs(x)
    u = (s*(128*np.log(1+scale*x)/np.log(256)))
    u = np.clip(128 + np.round(u), 0, 255)
    return u


def run_lpc(signal, lpcs, frame_length=160):
    """Run frame-wise linear prediction over a signal.

    Args:
        signal: 1-d array holding ``lpc_order`` samples of history followed
            by ``num_frames * frame_length`` samples.
        lpcs: array of shape (num_frames, lpc_order) with one coefficient
            set per frame.
        frame_length: number of samples predicted per coefficient set.

    Returns:
        Tuple ``(prediction, error)``: the negated convolution of each frame
        with its coefficients, concatenated over all frames, and
        ``signal[lpc_order:] - prediction``.
    """
    num_frames, lpc_order = lpcs.shape

    # a 'valid' convolution of this many samples yields frame_length outputs
    window = frame_length + lpc_order - 1

    per_frame = []
    for frame_idx in range(num_frames):
        begin = frame_idx * frame_length
        segment = signal[begin : begin + window]
        per_frame.append(-np.convolve(segment, lpcs[frame_idx], mode='valid'))

    prediction = np.concatenate(per_frame)
    error = signal[lpc_order :] - prediction

    return prediction, error

class LPCNetDataset(Dataset):
    """Map-style dataset serving fixed-length samples for LPCNet training.

    The dataset folder must contain an ``info.yml`` file providing the keys
    ``version``, ``feature_file``, ``feature_dtype``, ``feature_frame_length``,
    ``feature_frame_layout``, ``signal_file``, ``signal_dtype``,
    ``signal_frame_length``, ``signal_frame_layout`` and ``frame_length``.
    Feature and signal files are memory-mapped with numpy.

    Every item is a dict with the entries

    * ``'features'``: float tensor with all requested features except
      ``'periods'`` concatenated along the last dimension; it covers
      ``feature_history + frames_per_sample + feature_lookahead`` frames,
    * ``'periods'``: long tensor with the converted pitch periods,
    * ``'signals'``: long tensor with the requested input signals stacked
      along the last dimension,
    * ``'target'``: long tensor with the target signal.

    Args:
        path_to_dataset: folder containing ``info.yml`` and the data files.
        features: names of the features to extract; every name must be a key
            of ``feature_frame_layout``. ``'periods'`` has to be included
            since the returned dict always carries a ``'periods'`` entry.
        input_signals: names of the signals stacked into ``'signals'``.
        target: name of the signal returned as ``'target'``.
        frames_per_sample: number of frames covered by one item.
        feature_history: number of extra feature frames before each item.
        feature_lookahead: number of extra feature frames after each item.
        lpc_gamma: weighting factor; the i-th LPC coefficient (1-based) is
            multiplied by ``lpc_gamma ** i`` (version 2 datasets only).

    Raises:
        ValueError: if the dataset version is neither 1 nor 2.
    """

    def __init__(self,
                 path_to_dataset,
                 features=('cepstrum', 'periods', 'pitch_corr'),
                 input_signals=('last_signal', 'prediction', 'last_error'),
                 target='error',
                 frames_per_sample=15,
                 feature_history=2,
                 feature_lookahead=2,
                 lpc_gamma=1):

        super().__init__()

        # load dataset info
        self.path_to_dataset = path_to_dataset
        with open(os.path.join(path_to_dataset, 'info.yml'), 'r') as f:
            # NOTE(review): FullLoader is not meant for untrusted input;
            # consider yaml.safe_load if info.yml may come from third parties.
            dataset = yaml.load(f, yaml.FullLoader)

        # dataset version
        self.version = dataset['version']
        if self.version == 1:
            self.getitem = self.getitem_v1
        elif self.version == 2:
            self.getitem = self.getitem_v2
        else:
            raise ValueError(f"dataset version {self.version} unknown")

        # features
        self.feature_history      = feature_history
        self.feature_lookahead    = feature_lookahead
        self.frame_offset         = 1 + self.feature_history
        self.frames_per_sample    = frames_per_sample
        # stored as a copy so that the caller's sequence is never shared
        self.input_features       = list(features)
        self.feature_frame_layout = dataset['feature_frame_layout']
        self.lpc_gamma            = lpc_gamma

        # load feature file
        # NOTE(review): np.memmap defaults to mode 'r+', i.e. the data files
        # are opened read-write although this class only reads from them.
        self.feature_file = os.path.join(path_to_dataset, dataset['feature_file'])
        self.features = np.memmap(self.feature_file, dtype=dataset['feature_dtype'])
        self.feature_frame_length = dataset['feature_frame_length']

        assert len(self.features) % self.feature_frame_length == 0, \
            "feature file size is not a multiple of feature_frame_length"
        self.features = self.features.reshape((-1, self.feature_frame_length))

        # derive number of samples in dataset (clamped so that len() stays
        # valid when the feature file is too short for a single sample)
        num_usable_frames = len(self.features) - self.frame_offset - self.feature_lookahead - 1
        self.dataset_length = max(0, num_usable_frames // self.frames_per_sample)

        # signals
        self.frame_length               = dataset['frame_length']
        self.signal_frame_layout        = dataset['signal_frame_layout']
        self.input_signals              = list(input_signals)
        self.target                     = target

        # load signals
        self.signal_file  = os.path.join(path_to_dataset, dataset['signal_file'])
        self.signals  = np.memmap(self.signal_file, dtype=dataset['signal_dtype'])
        self.signal_frame_length  = dataset['signal_frame_length']
        self.signals = self.signals.reshape((-1, self.signal_frame_length))
        assert len(self.signals) == len(self.features) * self.frame_length, \
            "signal file length does not match number of feature frames"

    def __getitem__(self, index):
        """Return the sample at ``index``; negative indices count from the end.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        num_samples = self.dataset_length
        if index < 0:
            index += num_samples
        if not 0 <= index < num_samples:
            raise IndexError(f"sample index out of range for dataset of length {num_samples}")

        return self.getitem(index)

    def _extract_features(self, index):
        """Slice all requested features for sample ``index`` and convert periods."""
        first_frame = self.frame_offset + index * self.frames_per_sample
        frame_start = first_frame - self.feature_history
        frame_stop  = first_frame + self.frames_per_sample + self.feature_lookahead

        sample = dict()
        for feature in self.input_features:
            feature_start, feature_stop = self.feature_frame_layout[feature]
            sample[feature] = self.features[frame_start : frame_stop, feature_start : feature_stop]

        # convert periods
        if 'periods' in self.input_features:
            sample['periods'] = (0.1 + 50 * sample['periods'] + 100).astype('int16')

        return sample

    def _signal_span(self, index):
        """Return first and one-past-last signal row of sample ``index``."""
        first_frame = self.frame_offset + index * self.frames_per_sample
        signal_start = first_frame * self.frame_length
        signal_stop  = (first_frame + self.frames_per_sample) * self.frame_length

        return signal_start, signal_stop

    def _assemble(self, sample, transform=None):
        """Turn the arrays in ``sample`` into the returned dict of tensors.

        ``transform``, if given, is applied to every input signal and to the
        target before the conversion to long tensors.
        """
        def to_long(values):
            if transform is not None:
                values = transform(values)
            return torch.LongTensor(values)

        # concatenate features
        feature_keys = [key for key in self.input_features if not key.startswith("periods")]
        features = torch.cat([torch.FloatTensor(sample[key]) for key in feature_keys], dim=-1)
        signals = torch.cat([to_long(sample[key]).unsqueeze(-1) for key in self.input_signals], dim=-1)
        target  = to_long(sample[self.target])
        periods = torch.LongTensor(sample['periods'])

        return {'features' : features, 'periods' : periods, 'signals' : signals, 'target' : target}

    def getitem_v2(self, index):
        """Build a sample from a version 2 dataset.

        Only ``'last_signal'`` and ``'signal'`` are read from the signal file.
        If the features contain ``'lpc'`` and the signal layout has no
        ``'prediction'`` entry, ``'prediction'``, ``'last_error'`` and
        ``'error'`` are computed from the weighted LPC coefficients. Input
        signals and target are passed through ``lin2ulaw``.
        """
        sample = self._extract_features(index)
        signal_start, signal_stop = self._signal_span(index)

        # last_signal and signal are always expected to be there
        last_signal_column = self.signal_frame_layout['last_signal']
        sample['last_signal'] = self.signals[signal_start : signal_stop, last_signal_column]
        sample['signal'] = self.signals[signal_start : signal_stop, self.signal_frame_layout['signal']]

        # calculate prediction and error if lpc coefficients present and prediction not given
        if 'lpc' in self.feature_frame_layout and 'prediction' not in self.signal_frame_layout:
            # frame positions (start one frame early for past excitation)
            frame_start = self.frame_offset + self.frames_per_sample * index - 1
            frame_stop  = self.frame_offset + self.frames_per_sample * (index + 1)

            # feature positions
            lpc_start, lpc_stop = self.feature_frame_layout['lpc']
            lpc_order = lpc_stop - lpc_start
            lpcs = self.features[frame_start : frame_stop, lpc_start : lpc_stop]

            # LPC weighting
            weights = np.array([self.lpc_gamma ** (i + 1) for i in range(lpc_order)])
            lpcs = lpcs * weights

            # signal position (lpc_order samples as history)
            history_start = frame_start * self.frame_length - lpc_order + 1
            history_stop  = frame_stop  * self.frame_length + 1
            noisy_signal = self.signals[history_start : history_stop, last_signal_column]

            noisy_prediction, noisy_error = run_lpc(noisy_signal, lpcs, frame_length=self.frame_length)

            # extract signals (skip the extra leading frame again)
            offset = self.frame_length
            num_samples = self.frame_length * self.frames_per_sample
            sample['prediction'] = noisy_prediction[offset : offset + num_samples]
            sample['last_error'] = noisy_error[offset - 1 : offset - 1 + num_samples]

            # calculate error between real signal and noisy prediction
            sample['error'] = sample['signal'] - sample['prediction']

        return self._assemble(sample, transform=lin2ulaw)

    def getitem_v1(self, index):
        """Build a sample from a version 1 dataset.

        Every signal listed in the signal layout is read from the signal file
        and converted to long tensors without further transformation.
        """
        sample = self._extract_features(index)
        signal_start, signal_stop = self._signal_span(index)

        # last_signal and signal are always expected to be there
        for signal_name, column in self.signal_frame_layout.items():
            sample[signal_name] = self.signals[signal_start : signal_stop, column]

        return self._assemble(sample)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.dataset_length