ref: e12c7f584a542ec6f2b00be67bf33165e182b24b
dir: /dnn/torch/osce/stndrd/presentation/linear_prediction.ipynb/
sed: Output line too long
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.animation\n",
"from scipy.io import wavfile\n",
"import scipy.signal\n",
"import torch\n",
"\n",
"from playback import make_playback_animation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def load_lpcnet_features(feature_file, version=2):\n",
" if version == 2 or version == 3:\n",
" layout = {\n",
" 'cepstrum': [0,18],\n",
" 'periods': [18, 19],\n",
" 'pitch_corr': [19, 20],\n",
" 'lpc': [20, 36]\n",
" }\n",
" frame_length = 36\n",
"\n",
" elif version == 1:\n",
" layout = {\n",
" 'cepstrum': [0,18],\n",
" 'periods': [36, 37],\n",
" 'pitch_corr': [37, 38],\n",
" 'lpc': [39, 55],\n",
" }\n",
" frame_length = 55\n",
" else:\n",
" raise ValueError(f'unknown feature version: {version}')\n",
"\n",
"\n",
" raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))\n",
" raw_features = raw_features.reshape((-1, frame_length))\n",
"\n",
" features = torch.cat(\n",
" [\n",
" raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]],\n",
" raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]\n",
" ],\n",
" dim=1\n",
" )\n",
"\n",
" lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]\n",
" if version < 3:\n",
" periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()\n",
" else:\n",
" periods = torch.round(torch.clip(256./2**(raw_features[:, layout['periods'][0] : layout['periods'][1]] + 1.5), 32, 256)).long()\n",
"\n",
" return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def run_lpc(signal, lpcs, frame_length=160):\n",
" num_frames, lpc_order = lpcs.shape\n",
"\n",
" prediction = np.concatenate(\n",
" [- np.convolve(signal[i * frame_length : (i + 1) * frame_length + lpc_order - 1], lpcs[i], mode='valid') for i in range(num_frames)]\n",
" )\n",
" error = signal[lpc_order :] - prediction\n",
"\n",
" return prediction, error"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"lpcnet_features = load_lpcnet_features('lp/features.f32')\n",
"\n",
"features = lpcnet_features['features'].numpy()\n",
"periods = lpcnet_features['periods'].squeeze(-1).numpy()\n",
"lpcs = lpcnet_features['lpcs'].numpy()\n",
"\n",
"x = np.fromfile('data/a3_short.pcm', dtype=np.int16).astype(np.float32) / 2**15\n",
"x = np.concatenate((np.zeros(80), x, np.zeros(320)))\n",
"x_preemph = x.copy()\n",
"x_preemph[1:] -= 0.85 * x_preemph[:-1]\n",
"\n",
"num_frames = features.shape[0]\n",
"x = x[:160 * num_frames]\n",
"x_preemph = x_preemph[:160 * num_frames]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# short-term prediction\n",
"pred, error = run_lpc(np.concatenate((np.zeros(16), x_preemph)), lpcs)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# long-term prediction\n",
"offset = 256\n",
"padded_error = np.concatenate((np.zeros(offset), error))\n",
"ltp_error = padded_error.copy()\n",
"for i, p in enumerate(list(periods)):\n",
" t0 = i * 160 + offset\n",
" t1 = t0 + 160\n",
" \n",
" past = padded_error[t0 - p : t1 - p]\n",
" current = padded_error[t0 : t1]\n",
" \n",
" gain = np.dot(past, current) / (np.dot(past, past) + 1e-6)\n",
" ltp_error[t0 : t1] -= gain * past\n",
" \n",
" \n",
"ltp_error = ltp_error[offset:]\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {