ref: 46f9c9c6698f64a9e63a68dcc507042322bd3df9
dir: /dnn/torch/osce/stndrd/presentation/linear_prediction.ipynb/
sed: Output line too long
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation\n",
    "from scipy.io import wavfile\n",
    "import scipy.signal\n",
    "import torch\n",
    "\n",
    "from playback import make_playback_animation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_lpcnet_features(feature_file, version=2):\n",
    "    \"\"\"Load an LPCNet feature file and split it into its feature groups.\n",
    "\n",
    "    The file is read as flat float32 data and reshaped into frames of 36 values\n",
    "    (versions 2 and 3) or 55 values (version 1).\n",
    "\n",
    "    Returns a dict of torch tensors:\n",
    "        'features': cepstrum (18 columns) followed by pitch correlation (1 column)\n",
    "        'periods' : integer pitch periods, shape (num_frames, 1)\n",
    "        'lpcs'    : LPC coefficients, 16 columns\n",
    "\n",
    "    Raises ValueError for an unknown version.\n",
    "    \"\"\"\n",
    "    if version in (2, 3):\n",
    "        frame_length = 36\n",
    "        layout = {\n",
    "            'cepstrum'   : [0, 18],\n",
    "            'periods'    : [18, 19],\n",
    "            'pitch_corr' : [19, 20],\n",
    "            'lpc'        : [20, 36],\n",
    "        }\n",
    "    elif version == 1:\n",
    "        frame_length = 55\n",
    "        layout = {\n",
    "            'cepstrum'   : [0, 18],\n",
    "            'periods'    : [36, 37],\n",
    "            'pitch_corr' : [37, 38],\n",
    "            'lpc'        : [39, 55],\n",
    "        }\n",
    "    else:\n",
    "        raise ValueError(f'unknown feature version: {version}')\n",
    "\n",
    "    flat = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))\n",
    "    frames = flat.reshape((-1, frame_length))\n",
    "\n",
    "    def columns(name):\n",
    "        # half-open column range [start, stop) of one feature group\n",
    "        start, stop = layout[name]\n",
    "        return frames[:, start:stop]\n",
    "\n",
    "    features = torch.cat([columns('cepstrum'), columns('pitch_corr')], dim=1)\n",
    "    lpcs = columns('lpc')\n",
    "\n",
    "    raw_periods = columns('periods')\n",
    "    if version < 3:\n",
    "        # period = 50 * value + 100; the 0.1 presumably guards against float error\n",
    "        # before .long() truncates\n",
    "        periods = (0.1 + 50 * raw_periods + 100).long()\n",
    "    else:\n",
    "        # version 3 stores log2(256 / period) - 1.5; map back and clamp to [32, 256]\n",
    "        periods = torch.round(torch.clip(256./2**(raw_periods + 1.5), 32, 256)).long()\n",
    "\n",
    "    return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_lpc(signal, lpcs, frame_length=160):\n",
    "    \"\"\"Frame-wise short-term linear prediction.\n",
    "\n",
    "    Frame i uses the coefficients lpcs[i]:\n",
    "        prediction[n] = -sum_k lpcs[i, k] * x[n - 1 - k],  x = signal[lpc_order:]\n",
    "    so signal must carry lpc_order samples of history (e.g. zeros) in front of\n",
    "    the first predicted sample.\n",
    "\n",
    "    Args:\n",
    "        signal: 1-D array with at least lpc_order + num_frames * frame_length\n",
    "            samples; samples beyond that are ignored.\n",
    "        lpcs: array of shape (num_frames, lpc_order).\n",
    "        frame_length: number of samples predicted with each row of lpcs.\n",
    "\n",
    "    Returns:\n",
    "        (prediction, error), each of length num_frames * frame_length, with\n",
    "        error = x - prediction.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if signal is too short for the number of frames.\n",
    "    \"\"\"\n",
    "    num_frames, lpc_order = lpcs.shape\n",
    "    num_samples = num_frames * frame_length\n",
    "    if len(signal) < lpc_order + num_samples:\n",
    "        raise ValueError(\n",
    "            f'signal has {len(signal)} samples, need at least {lpc_order + num_samples} '\n",
    "            '(lpc_order + num_frames * frame_length)'\n",
    "        )\n",
    "\n",
    "    # a 'valid' convolution of frame_length + lpc_order - 1 samples with lpc_order\n",
    "    # taps yields exactly frame_length predictions for frame i\n",
    "    frames = [\n",
    "        -np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid')\n",
    "        for i in range(num_frames)\n",
    "    ]\n",
    "    prediction = np.concatenate(frames) if frames else np.zeros(0)\n",
    "\n",
    "    # slice the target explicitly so that longer (padded) signals do not cause a shape mismatch\n",
    "    error = signal[lpc_order : lpc_order + num_samples] - prediction\n",
    "\n",
    "    return prediction, error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-frame LPCNet features, pitch periods and LPC coefficients\n",
    "lpcnet_features = load_lpcnet_features('lp/features.f32')\n",
    "features, periods, lpcs = (\n",
    "    lpcnet_features['features'].numpy(),\n",
    "    lpcnet_features['periods'].squeeze(-1).numpy(),\n",
    "    lpcnet_features['lpcs'].numpy(),\n",
    ")\n",
    "\n",
    "# int16 PCM scaled to [-1, 1), with 80 leading and 320 trailing zeros\n",
    "x = np.concatenate((\n",
    "    np.zeros(80),\n",
    "    np.fromfile('data/a3_short.pcm', dtype=np.int16).astype(np.float32) / 2**15,\n",
    "    np.zeros(320),\n",
    "))\n",
    "\n",
    "# first-order pre-emphasis y[n] = x[n] - 0.85 * x[n - 1]; y[0] = x[0]\n",
    "x_preemph = np.concatenate((x[:1], x[1:] - 0.85 * x[:-1]))\n",
    "\n",
    "# keep 160 samples of audio per feature frame\n",
    "num_frames = features.shape[0]\n",
    "x = x[:160 * num_frames]\n",
    "x_preemph = x_preemph[:160 * num_frames]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# short-term prediction: prepend one zero per LPC coefficient so the first\n",
    "# frame is predicted from silent history (the padding follows the LPC order\n",
    "# instead of a hard-coded 16)\n",
    "pred, error = run_lpc(np.concatenate((np.zeros(lpcs.shape[1]), x_preemph)), lpcs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# long-term (pitch) prediction on the short-term residual: each frame is\n",
    "# predicted from the residual one pitch period earlier, scaled by a per-frame gain\n",
    "frame_size = 160\n",
    "# the zero history must cover the longest pitch lag, otherwise t0 - p < 0\n",
    "offset = max(256, int(periods.max(initial=0)))\n",
    "padded_error = np.concatenate((np.zeros(offset), error))\n",
    "ltp_error = padded_error.copy()\n",
    "for i, p in enumerate(periods):\n",
    "    t0 = i * frame_size + offset\n",
    "    t1 = t0 + frame_size\n",
    "\n",
    "    # read from padded_error (the unmodified residual), not from ltp_error\n",
    "    past = padded_error[t0 - p : t1 - p]\n",
    "    current = padded_error[t0 : t1]\n",
    "\n",
    "    # least-squares gain for current ~ gain * past; 1e-6 avoids division by zero\n",
    "    gain = np.dot(past, current) / (np.dot(past, past) + 1e-6)\n",
    "    ltp_error[t0 : t1] -= gain * past\n",
    "\n",
    "ltp_error = ltp_error[offset:]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {