ref: 4f4b6242099998d7acf89e17c287dc7f605af607
dir: /dnn/torch/rdovae/fec_encoder.py/
"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe and Jean-Marc Valin */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

# Offline encoder for Opus neural FEC redundancy.
#
# Pipeline:
#   1. Prepend zero history / alignment delay to the int16 input, pad it to a
#      whole number of 20 ms frames and run the external `dump_data -test`
#      tool on it to compute features (one record per 10 ms subframe).
#   2. Encode the features with an RDOVAE checkpoint.
#   3. For every 20 ms frame, take every other one of the most recent latent
#      vectors (num_redundancy_frames // 2 of them), quantize them with levels
#      interpolated from q1 (oldest) to q0 (most recent), decode them back to
#      features and store those together with an estimated packet size
#      through write_fec_packets into `<output>.fec`.
#   4. Optionally assemble the decoded features according to a loss trace
#      (--lossfile) and write debug feature dumps (--debug-output).
#
# The intermediate files `<input stem>_padded.raw` and
# `<input stem>_features.f32` are written next to the input and left on disk.

import os
import subprocess
import argparse

# Hide all GPUs before torch is imported below; everything runs on the CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ""

parser = argparse.ArgumentParser(description='Encode redundancy for Opus neural FEC. Designed for use with voip application and 20ms frames')

parser.add_argument('input', metavar='<input signal>', help='audio input (.wav or .raw or .pcm as int16)')
parser.add_argument('checkpoint', metavar='<weights>', help='model checkpoint')
parser.add_argument('q0', metavar='<quant level 0>', type=int, help='quantization level for most recent frame')
parser.add_argument('q1', metavar='<quant level 1>', type=int, help='quantization level for oldest frame')
parser.add_argument('output', type=str, help='output file (will be extended with .fec)')
parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
parser.add_argument('--num-redundancy-frames', default=52, type=int, help='number of redundancy frames per packet (default 52)')
parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)")
parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')
parser.add_argument('--debug-output', action='store_true', help='if set, differently assembled features are written to disk')

args = parser.parse_args()

# Each packet carries num_redundancy_frames // 2 latent vectors and the
# quantization schedule below divides by (num_redundancy_frames // 2 - 1), so
# at least two latents (four redundancy frames) are required.
if args.num_redundancy_frames < 4:
    parser.error('--num-redundancy-frames must be at least 4')

# Imported only after argument parsing (and after hiding the GPUs above), so
# `--help` and argument errors exit before torch is loaded.
import numpy as np
from scipy.io import wavfile
import torch

from rdovae import RDOVAE
from packets import write_fec_packets

torch.set_num_threads(4)

# NOTE(review): torch.load may unpickle arbitrary objects from the
# checkpoint file; only pass trusted checkpoints.
checkpoint = torch.load(args.checkpoint, map_location="cpu")

model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
# strict=False: missing / unexpected state-dict keys are ignored.
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.to("cpu")
# NOTE(review): the model is never switched to eval(); if RDOVAE contains
# dropout or batch-norm layers this affects the encoded output -- confirm.

lpc_order = 16

## prepare input signal
# SILK frame size is 20ms and LPCNet subframes are 10ms (at 16 kHz)
subframe_size = 160
frame_size = 2 * subframe_size

# 91 samples delay to align with SILK decoded frames
silk_delay = 91

# prepend zeros to have enough history to produce the first package
zero_history = (args.num_redundancy_frames - 1) * frame_size

# dump data has a (feature) delay of 10ms
dump_data_delay = 160

total_delay = silk_delay + zero_history + args.extra_delay - dump_data_delay

# load signal
if args.input.endswith('.raw') or args.input.endswith('.pcm'):
    signal = np.fromfile(args.input, dtype='int16')
elif args.input.endswith('.wav'):
    fs, signal = wavfile.read(args.input)
    # The padded file handed to dump_data is headerless int16 framed at
    # 16 kHz (320 samples per 20 ms), so reject anything else instead of
    # silently writing a mis-formatted file.
    if fs != 16000 or signal.ndim != 1 or signal.dtype != np.int16:
        raise ValueError(f'expected a 16 kHz mono int16 wav file, got fs={fs}, shape={signal.shape}, dtype={signal.dtype}: {args.input}')
else:
    raise ValueError(f'unknown input signal format: {args.input}')

# fill up last frame with zeros
padded_signal_length = len(signal) + total_delay
tail = padded_signal_length % frame_size
right_padding = (frame_size - tail) % frame_size

signal = np.concatenate((np.zeros(total_delay, dtype=np.int16), signal, np.zeros(right_padding, dtype=np.int16)))

padded_signal_file = os.path.splitext(args.input)[0] + '_padded.raw'
signal.tofile(padded_signal_file)

# write signal and call dump_data to create features
feature_file = os.path.splitext(args.input)[0] + '_features.f32'
# Argument list without a shell: paths containing spaces or shell
# metacharacters are passed through verbatim.
command = [args.dump_data, '-test', padded_signal_file, feature_file]
r = subprocess.run(command)
if r.returncode != 0:
    command_str = ' '.join(command)
    raise RuntimeError(f"command '{command_str}' failed with exit code {r.returncode}")

# load features: dump_data writes nb_features float32 values per subframe,
# of which only the first nb_used_features are fed to the model
nb_features = model.feature_dim + lpc_order
nb_used_features = model.feature_dim

features = np.fromfile(feature_file, dtype='float32')
num_subframes = len(features) // nb_features
# keep an even number of subframes, i.e. whole 20 ms frames
num_subframes = 2 * (num_subframes // 2)
num_frames = num_subframes // 2

# drop any trailing partial record before reshaping
features = np.reshape(features[:num_subframes * nb_features], (1, num_subframes, nb_features))
features = features[:, :, :nb_used_features]

# number of latent vectors carried per packet, and index of the first frame
# with enough history to produce a packet
num_latents = args.num_redundancy_frames // 2
offset = args.num_redundancy_frames - 1

if num_frames <= offset:
    raise RuntimeError(f'input too short: {num_frames} frames available, need more than {offset} to produce a packet')

# quant_ids in reverse decoding order: linear ramp from q1 (oldest latent)
# to q0 (most recent latent), rounded to integer levels
quant_ids = torch.round(args.q1 + (args.q0 - args.q1) * torch.arange(num_latents) / (num_latents - 1)).long()

print(f"using quantization levels {quant_ids}...")

# convert input to torch tensors
features = torch.from_numpy(features)

# run encoder
print("running fec encoder...")
with torch.no_grad():

    # encoding
    z, states, state_size = model.encode(features)

    # decoder on packet chunks
    packets = []
    packet_sizes = []

    for i in range(offset, num_frames):
        print(f"processing frame {i - offset}...")

        # quantize / unquantize every other latent vector, ending at frame i
        zi = torch.clone(z[:, i - 2 * num_latents + 2: i + 1 : 2, :])
        zi, rates = model.quantize(zi, quant_ids)
        zi = model.unquantize(zi, quant_ids)

        packet_features = model.decode(zi, states[:, i : i + 1, :])
        packets.append(packet_features.squeeze(0).numpy())

        # round the estimated size up to a multiple of 8 bits (exact for
        # integer totals; int() truncates fractional ones)
        packet_size = 8 * int((torch.sum(rates) + 7 + state_size) / 8)
        packet_sizes.append(packet_size)


# write packets
packet_file = args.output + '.fec' if not args.output.endswith('.fec') else args.output
write_fec_packets(packet_file, packets, packet_sizes)

# packet sizes are in bits, one packet per 20 ms frame (50 packets/s)
print(f"average redundancy rate: {int(round(sum(packet_sizes) / len(packet_sizes) * 50 / 1000))} kbps")

# assemble features according to loss file
if args.lossfile is not None:
    num_packets = len(packets)
    # atleast_1d: loadtxt returns a 0-d array for a single-entry trace
    loss = np.atleast_1d(np.loadtxt(args.lossfile, dtype='int16'))
    if len(loss) < num_packets:
        raise ValueError(f'loss trace {args.lossfile} has {len(loss)} entries, need at least {num_packets}')

    fec_out = np.zeros((num_packets * 2, packets[0].shape[-1]), dtype='float32')

    # A used packet (loss[i] == 0, or the final packet) contributes its newest
    # `count` rows: 2 for its own frame plus 2 for every immediately preceding
    # lost packet, whose frames are taken from this packet's redundancy.
    ptr = 0
    count = 2
    for i in range(num_packets):
        if (loss[i] == 0) or (i == num_packets - 1):
            fec_out[ptr:ptr + count, :] = packets[i][-count:, :]
            ptr += count
            count = 2
        else:
            count += 2

    # widen to the same per-row layout as the dump_data feature file
    fec_out_full = np.zeros((fec_out.shape[0], nb_features), dtype=np.float32)
    fec_out_full[:, : fec_out.shape[-1]] = fec_out

    fec_out_full.tofile(packet_file[:-4] + '_fec.f32')


if args.debug_output:
    import itertools

    batches = [4]
    offsets = [0, 2 * args.num_redundancy_frames - 4]

    # sanity checks
    # 1. concatenate features at offset 0
    # For each (batch, debug_offset) pair take `batch` consecutive decoded rows
    # ending `debug_offset` rows before the end of every (batch // 2)-th
    # packet. Offset 0 selects the newest rows; 2 * num_redundancy_frames - 4
    # the oldest ones (assuming a packet decodes to 2 * num_redundancy_frames
    # rows -- TODO confirm).
    for batch, debug_offset in itertools.product(batches, offsets):

        # rows live on axis 0 of each (rows, features) packet
        stop = packets[0].shape[0] - debug_offset
        test_features = np.concatenate([packet[stop - batch: stop, :] for packet in packets[::batch // 2]], axis=0)

        test_features_full = np.zeros((test_features.shape[0], nb_features), dtype=np.float32)
        test_features_full[:, :nb_used_features] = test_features[:, :]

        debug_file = packet_file[:-4] + f'_torch_batch{batch}_offset{debug_offset}.f32'
        print(f"writing debug output {debug_file}")
        test_features_full.tofile(debug_file)