ref: e12c7f584a542ec6f2b00be67bf33165e182b24b
dir: /dnn/torch/osce/stndrd/presentation/linear_prediction.ipynb/
def load_lpcnet_features(feature_file, version=2):
    """Load an LPCNet feature file and split it into its components.

    The file is read as a flat stream of float32 values and interpreted as
    consecutive frames. Frame width and column layout depend on ``version``:
    versions 2 and 3 use 36 values per frame, version 1 uses 55.

    Args:
        feature_file: File name (or open file) handed to ``np.fromfile``.
        version: Feature layout version (1, 2 or 3).

    Returns:
        dict with three tensors, one row per frame:
            'features': the 18 cepstrum columns followed by the
                pitch-correlation column, shape (num_frames, 19).
            'periods': integer (long) pitch periods, shape (num_frames, 1).
            'lpcs': the 16 LPC columns, shape (num_frames, 16).

    Raises:
        ValueError: if ``version`` is not 1, 2 or 3.
        RuntimeError: raised by ``reshape`` when the number of values in the
            file is not a multiple of the frame width.
    """
    if version in (2, 3):
        layout = {
            'cepstrum': [0, 18],
            'periods': [18, 19],
            'pitch_corr': [19, 20],
            'lpc': [20, 36],
        }
        frame_length = 36
    elif version == 1:
        layout = {
            'cepstrum': [0, 18],
            'periods': [36, 37],
            'pitch_corr': [37, 38],
            'lpc': [39, 55],
        }
        frame_length = 55
    else:
        raise ValueError(f'unknown feature version: {version}')

    raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))

    # Spell out the frame count instead of using -1: a -1 dimension is
    # ambiguous for a zero-element tensor, so an empty file would crash
    # instead of producing zero frames. A size that is not a multiple of
    # frame_length still fails inside reshape, exactly as before.
    num_frames = raw_features.numel() // frame_length
    raw_features = raw_features.reshape((num_frames, frame_length))

    def columns(name):
        # Column range of one named field across all frames.
        first, last = layout[name]
        return raw_features[:, first:last]

    features = torch.cat([columns('cepstrum'), columns('pitch_corr')], dim=1)
    lpcs = columns('lpc')

    if version < 3:
        # Linear mapping of the stored value; the 0.1 offset guards the
        # truncating .long() against values landing just below an integer.
        periods = (0.1 + 50 * columns('periods') + 100).long()
    else:
        # Exponential mapping, limited to [32, 256] before rounding.
        periods = torch.round(
            torch.clip(256. / 2 ** (columns('periods') + 1.5), 32, 256)
        ).long()

    return {'features': features, 'periods': periods, 'lpcs': lpcs}


def run_lpc(signal, lpcs, frame_length=160):
    """Apply frame-wise linear prediction to a signal.

    ``signal`` must start with ``lpc_order`` samples of history. Frame ``i``
    covers ``frame_length`` samples and is predicted with ``lpcs[i]``: the
    prediction of a sample is minus the sum over ``k`` of ``lpcs[i][k]`` times
    the sample ``k + 1`` positions earlier.

    Args:
        signal: 1-D array; ``lpc_order`` history samples followed by the
            samples to predict. Samples beyond the last frame are ignored.
        lpcs: Array of shape (num_frames, lpc_order), one coefficient set per
            frame.
        frame_length: Number of samples per frame.

    Returns:
        (prediction, error): the predicted samples and the residual
        ``signal - prediction``, both aligned with ``signal[lpc_order:]``.
    """
    num_frames, lpc_order = lpcs.shape

    if num_frames == 0:
        # np.concatenate rejects an empty list; no frames means no output.
        return np.zeros(0), np.zeros(0)

    frame_predictions = []
    for i in range(num_frames):
        start = i * frame_length
        # lpc_order - 1 extra samples make the 'valid' convolution return
        # exactly frame_length values for a full frame.
        segment = signal[start : start + frame_length + lpc_order - 1]
        frame_predictions.append(-np.convolve(segment, lpcs[i], mode='valid'))
    prediction = np.concatenate(frame_predictions)

    # Trim to the predicted span so that trailing samples past the last frame
    # do not cause a shape mismatch.
    error = signal[lpc_order : lpc_order + len(prediction)] - prediction

    return prediction, error
# Load the LPCNet features and unpack them into numpy arrays.
lpcnet_features = load_lpcnet_features('lp/features.f32')

features = lpcnet_features['features'].numpy()
periods = lpcnet_features['periods'].squeeze(-1).numpy()
lpcs = lpcnet_features['lpcs'].numpy()

# Read 16-bit PCM, scale by 2**15, then pad with 80 zeros in front and
# 320 zeros at the end.
x = np.fromfile('data/a3_short.pcm', dtype=np.int16).astype(np.float32) / 2**15
x = np.concatenate((np.zeros(80), x, np.zeros(320)))

# Pre-emphasised copy: every sample minus 0.85 times its predecessor
# (the first sample is left unchanged).
x_preemph = x.copy()
x_preemph[1:] = x[1:] - 0.85 * x[:-1]

# Keep exactly 160 samples per feature frame.
num_frames = features.shape[0]
num_samples = 160 * num_frames
x = x[:num_samples]
x_preemph = x_preemph[:num_samples]

# short-term prediction (16 zero samples serve as initial history)
lpc_input = np.concatenate((np.zeros(16), x_preemph))
pred, error = run_lpc(lpc_input, lpcs)

# long-term prediction
# The residual is padded with `offset` leading zeros so that looking back by
# one period stays inside the array for the first frames.
offset = 256
padded_error = np.concatenate((np.zeros(offset), error))
ltp_error = padded_error.copy()
for i, p in enumerate(periods):
    t0 = offset + 160 * i
    t1 = t0 + 160

    # Same-length segment delayed by p samples, and the current frame.
    past = padded_error[t0 - p : t1 - p]
    current = padded_error[t0 : t1]

    # Least-squares gain of `past` against `current`; 1e-6 keeps the
    # denominator away from zero.
    gain = np.dot(past, current) / (np.dot(past, past) + 1e-6)
    ltp_error[t0 : t1] = ltp_error[t0 : t1] - gain * past

# Drop the zero padding again.
ltp_error = ltp_error[offset:]